This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f1325133187 Fix Execution API refresh token (#58782)
f1325133187 is described below
commit f1325133187f830d465dcdf56ebf31a33bbaf340
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Fri Nov 28 16:35:13 2025 +0100
Fix Execution API refresh token (#58782)
---
.../src/airflow/api_fastapi/execution_api/app.py | 37 ++++++++++++-
.../src/airflow/api_fastapi/execution_api/deps.py | 62 +---------------------
.../api_fastapi/execution_api/routes/__init__.py | 4 +-
.../execution_api/versions/head/test_router.py | 8 ++-
.../versions/head/test_task_instances.py | 3 +-
5 files changed, 43 insertions(+), 71 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index 9e30ffb5af5..9d93f3bf84d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import json
+import time
from contextlib import AsyncExitStack
from functools import cached_property
from typing import TYPE_CHECKING, Any
@@ -126,6 +127,39 @@ class CorrelationIdMiddleware(BaseHTTPMiddleware):
return response
+class JWTReissueMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request: Request, call_next):
+ from airflow.configuration import conf
+
+ response: Response = await call_next(request)
+
+ refreshed_token: str | None = None
+ auth_header = request.headers.get("authorization")
+ if auth_header and auth_header.lower().startswith("bearer "):
+ token = auth_header.split(" ", 1)[1]
+ try:
+ async with svcs.Container(request.app.state.svcs_registry) as
services:
+ validator: JWTValidator = await services.aget(JWTValidator)
+ claims = await validator.avalidated_claims(token, {})
+
+ now = int(time.time())
+ validity = conf.getint("execution_api",
"jwt_expiration_time")
+ refresh_when_less_than = max(int(validity * 0.20), 30)
+ valid_left = int(claims.get("exp", 0)) - now
+ if valid_left <= refresh_when_less_than:
+ generator: JWTGenerator = await
services.aget(JWTGenerator)
+ refreshed_token = generator.generate(claims)
+ except Exception as err:
+ # Do not block the response if refreshing fails; log a warning
for visibility
+ logger.warning(
+ "JWT reissue middleware failed to refresh token",
error=str(err), exc_info=True
+ )
+
+ if refreshed_token:
+ response.headers["Refreshed-API-Token"] = refreshed_token
+ return response
+
+
class CadwynWithOpenAPICustomization(Cadwyn):
# Workaround lack of customzation
https://github.com/zmievsa/cadwyn/issues/255
async def openapi_jsons(self, req: Request) -> JSONResponse:
@@ -211,6 +245,7 @@ def create_task_execution_api_app() -> FastAPI:
# Add correlation-id middleware for request tracing
app.add_middleware(CorrelationIdMiddleware)
+ app.add_middleware(JWTReissueMiddleware)
app.generate_and_include_versioned_routers(execution_api_router)
@@ -266,7 +301,6 @@ class InProcessExecutionAPI:
from airflow.api_fastapi.execution_api.deps import (
JWTBearerDep,
JWTBearerTIPathDep,
- JWTRefresherDep,
)
from airflow.api_fastapi.execution_api.routes.connections import
has_connection_access
from airflow.api_fastapi.execution_api.routes.variables import
has_variable_access
@@ -281,7 +315,6 @@ class InProcessExecutionAPI:
self._app.dependency_overrides[JWTBearerDep.dependency] =
always_allow
self._app.dependency_overrides[JWTBearerTIPathDep.dependency] =
always_allow
- self._app.dependency_overrides[JWTRefresherDep.dependency] =
always_allow
self._app.dependency_overrides[has_connection_access] =
always_allow
self._app.dependency_overrides[has_variable_access] = always_allow
self._app.dependency_overrides[has_xcom_access] = always_allow
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
index 2648a64ffad..d247a31f5f4 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
@@ -18,17 +18,14 @@
# Disable future annotations in this file to work around
https://github.com/fastapi/fastapi/issues/13056
# ruff: noqa: I002
-import sys
-import time
from typing import Any
import structlog
import svcs
-from fastapi import Depends, HTTPException, Request, Response, status
+from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer
-from starlette.exceptions import HTTPException as StarletteHTTPException
-from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
+from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
log = structlog.get_logger(logger_name=__name__)
@@ -98,58 +95,3 @@ JWTBearerDep: TIToken = Depends(JWTBearer())
# This checks that the UUID in the url matches the one in the token for us.
JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))
-
-
-class JWTReissuer:
- """Re-issue JWTs to requests when they are about to run out."""
-
- def __init__(self):
- from airflow.configuration import conf
-
- self.refresh_when_less_than = max(
- # Issue a new token to a task when the current one is valid for
only either 20% of the total validity,
- # or 30s
- int(conf.getint("execution_api", "jwt_expiration_time") * 0.20),
- 30,
- )
-
- async def __call__(
- self,
- response: Response,
- token=JWTBearerDep,
- services=DepContainer,
- ):
- try:
- yield
- finally:
- # We want to run this even in the case of 404 errors etc
- now = int(time.time())
-
- try:
- valid_left = token.claims["exp"] - now
- if valid_left <= self.refresh_when_less_than:
- generator: JWTGenerator = await services.aget(JWTGenerator)
- new = generator.generate(token.claims)
- response.headers["Refreshed-API-Token"] = new
- log.debug(
- "Refreshed token issued to Task",
- valid_left=valid_left,
- refresh_when_less_than=self.refresh_when_less_than,
- )
-
- exc, val, _ = sys.exc_info()
- if val and isinstance(val, StarletteHTTPException):
- # If there is an exception thrown, we need to set the
headers there instead
- if val.headers is None:
- val.headers = {}
-
- # Defined as a "mapping type", but 99.9% of the time
it's a mutable dict. We catch
- # errors if not
- val.headers["Refreshed-API-Token"] = new # type:
ignore[index]
-
- except Exception as e:
- # Don't 500 if there's a problem
- log.warning("Error refreshing Task JWT",
err=f"{type(e).__name__}: {e}")
-
-
-JWTRefresherDep = Depends(JWTReissuer())
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
index 89d96083876..562b8588fbf 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from cadwyn import VersionedAPIRouter
from fastapi import APIRouter
-from airflow.api_fastapi.execution_api.deps import JWTBearerDep,
JWTRefresherDep
+from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.routes import (
asset_events,
assets,
@@ -37,7 +37,7 @@ execution_api_router = APIRouter()
execution_api_router.include_router(health.router, prefix="/health",
tags=["Health"])
# _Every_ single endpoint under here must be authenticated. Some do further
checks on top of these
-authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep,
JWTRefresherDep]) # type: ignore[list-item]
+authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) #
type: ignore[list-item]
authenticated_router.include_router(assets.router, prefix="/assets",
tags=["Assets"])
authenticated_router.include_router(asset_events.router,
prefix="/asset-events", tags=["Asset Events"])
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
index e4aadcead6c..85f7df46915 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
@@ -24,7 +24,6 @@ from fastapi import FastAPI
from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.execution_api.app import lifespan
-from airflow.api_fastapi.execution_api.deps import JWTRefresherDep, JWTReissuer
from tests_common.test_utils.config import conf_vars
@@ -57,15 +56,14 @@ def test_expiring_token_is_reissued(
"exp": moment + validity,
}
- with conf_vars({("execution_api", "jwt_expiration_time"): str(validity)}):
- exec_app.dependency_overrides[JWTRefresherDep.dependency] =
JWTReissuer()
-
time_machine.move_to(moment + age, tick=False)
# Inject our fake JWTValidator object. Can be over-ridden by tests if they
want
lifespan.registry.register_value(JWTValidator, auth)
# In order to test this we need any endpoint to hit. The easiest one to
use is variable get
- response = client.get("/execution/variables/key1")
+
+ with conf_vars({("execution_api", "jwt_expiration_time"): str(validity)}):
+ response = client.get("/execution/variables/key1",
headers={"Authorization": "Bearer dummy"})
if expect_refreshed_token:
assert "Refreshed-API-Token" in response.headers
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index c410d4e4733..f91ae7f285f 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -98,7 +98,6 @@ def test_id_matches_sub_claim(client, session,
create_task_instance):
raise RuntimeError("Fake auth denied")
return claims
- # validator.avalidated_claims.side_effect = [{}, RuntimeError("fail for
tests"), claims, claims]
validator.avalidated_claims.side_effect = side_effect
lifespan.registry.register_value(JWTValidator, validator)
@@ -113,7 +112,7 @@ def test_id_matches_sub_claim(client, session,
create_task_instance):
resp =
client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run",
json=payload)
assert resp.status_code == 403
- validator.avalidated_claims.assert_called_with(
+ assert validator.avalidated_claims.call_args_list[1] == mock.call(
mock.ANY, {"sub": {"essential": True, "value":
"9c230b40-da03-451d-8bd7-be30471be383"}}
)
validator.avalidated_claims.reset_mock()