This is an automated email from the ASF dual-hosted git repository.
vatsrahul1001 pushed a commit to branch v3-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-2-test by this push:
new 70cacb116f0 Two-token mechanism for task execution to prevent token
expiration while tasks wait in executor queues (#60108) (#66989)
70cacb116f0 is described below
commit 70cacb116f0c5678e9428931d2ff23f6f1540862
Author: Rahul Vats <[email protected]>
AuthorDate: Mon May 18 10:43:44 2026 +0530
Two-token mechanism for task execution to prevent token expiration while
tasks wait in executor queues (#60108) (#66989)
Tasks waiting in executor queues (Celery, Kubernetes) can have their JWT
tokens expire before execution starts, causing auth failures on the Execution
API. This is a real problem in production: when queues back up or workers are
slow to pick up tasks, the original short-lived token expires and the worker
gets a 403 when it finally tries to start the task.
Fixes: #53713
Related: #59553
closes: https://github.com/apache/airflow/pull/62129
## Approach
Two-token mechanism: a workload token (lifetime tracks [scheduler]
task_queued_timeout) travels with the task through the queue, and a short-lived
execution token is issued when the task actually starts running.
The workload token carries a scope: "workload" claim and is restricted to
the /run endpoint only, enforced via FastAPI SecurityScopes and a custom
ExecutionAPIRoute. When /run succeeds, it returns an execution token via the
Refreshed-API-Token header. The SDK client picks it up and uses it for all
subsequent API calls. The existing JWTReissueMiddleware handles refreshing
execution tokens near expiry and skips workload tokens.
Built on @ashb's SecurityScopes foundation.
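To make the flow above concrete, here is a minimal, self-contained sketch of
the token swap and the refresh rule. It is not the Airflow implementation:
`Token`, `issue_workload_token`, `run_endpoint`, and `should_refresh` are
hypothetical names, tokens are plain dataclasses rather than signed JWTs, and
the 3600-second queue timeout is an assumed value chosen for illustration.

```python
from dataclasses import dataclass

# Default [execution_api] jwt_expiration_time, in seconds.
EXECUTION_LIFETIME = 600


@dataclass(frozen=True)
class Token:
    """Stand-in for the JWT claims that matter here (hypothetical type)."""

    sub: str
    scope: str
    iat: int
    exp: int


def issue_workload_token(ti_id: str, now: int, task_queued_timeout: int) -> Token:
    """Scheduler side: the token that travels with the task through the queue."""
    return Token(ti_id, "workload", now, now + task_queued_timeout)


def run_endpoint(token: Token, now: int) -> dict[str, Token]:
    """Server side of /run: return the response headers for a valid token."""
    if now >= token.exp:
        raise PermissionError("token expired")
    headers: dict[str, Token] = {}
    # Only a workload token triggers the swap to an execution token.
    if token.scope == "workload":
        headers["Refreshed-API-Token"] = Token(
            token.sub, "execution", now, now + EXECUTION_LIFETIME
        )
    return headers


def should_refresh(token: Token, now: int) -> bool:
    """Refresh rule: skip workload tokens, refresh below 20% (min 30 s)."""
    if token.scope == "workload":
        return False
    lifetime = token.exp - token.iat
    threshold = max(int(lifetime * 0.20), 30)
    return token.exp - now <= threshold


# The task waits 1800 s in the queue, longer than an execution token would live.
workload = issue_workload_token("ti-123", now=0, task_queued_timeout=3600)
execution = run_endpoint(workload, now=1800)["Refreshed-API-Token"]
```

In this sketch the worker would replace its bearer token with `execution`
after `/run` and keep using it; `should_refresh` only starts returning true
once 120 seconds or less remain on the 600-second execution token.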
### Security considerations
Even if a workload token is intercepted, it can only call /run, which
already guards against running a task more than once (it returns 409 if the
task isn't in QUEUED/RESTARTING state). All other endpoints reject workload
tokens; they require execution scope. The execution token issued by /run is
short-lived and automatically refreshed, keeping the existing security posture
for all API calls during task execution.
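The restrictions described above can be sketched as a small scope check. This
is an illustrative model only, assuming a hypothetical route table
(`ROUTE_SCOPES`) and hypothetical helpers (`authorize`, `run_task`); the real
enforcement lives in `require_auth` and `ExecutionAPIRoute`, and the detail
strings are simplified from the messages asserted in the tests.

```python
VALID_SCOPES = {"execution", "workload"}

# Hypothetical route table: routes not listed accept only "execution".
ROUTE_SCOPES = {"/run": {"execution", "workload"}}

# /run only proceeds for task instances in one of these states.
RUNNABLE_STATES = {"queued", "restarting"}


def authorize(claims: dict, route: str) -> tuple[int, str]:
    """Return (status, detail) for a token's claims on a given route."""
    # Tokens without a scope claim default to "execution".
    scope = claims.get("scope", "execution")
    if scope not in VALID_SCOPES:
        return 403, "Invalid auth token"
    allowed = ROUTE_SCOPES.get(route, {"execution"})
    if scope not in allowed:
        return 403, f"Token type '{scope}' not allowed"
    return 200, "ok"


def run_task(claims: dict, ti_state: str) -> int:
    """Model of /run: the scope check first, then the state guard."""
    status, _ = authorize(claims, "/run")
    if status != 200:
        return status
    if ti_state not in RUNNABLE_STATES:
        return 409  # already running or finished, so a replay is rejected
    return 200
```

Under these assumptions, an intercepted workload token gets 403 everywhere
except `/run`, and replaying it against a task that has already started
gets 409.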
(cherry picked from commit 2b6e8181e3ae2d4816a67f1be020c8effbfed440)
Co-authored-by: Anish Giri <[email protected]>
---
.../docs/security/jwt_token_authentication.rst | 61 +++++---
.../src/airflow/api_fastapi/auth/tokens.py | 10 +-
.../src/airflow/api_fastapi/execution_api/app.py | 13 +-
.../execution_api/routes/task_instances.py | 20 ++-
.../src/airflow/config_templates/config.yml | 4 +
.../src/airflow/executors/workloads/base.py | 10 +-
.../tests/unit/api_fastapi/auth/test_tokens.py | 28 ++++
.../unit/api_fastapi/execution_api/conftest.py | 19 ++-
.../execution_api/versions/head/test_router.py | 6 +-
.../versions/head/test_task_instances.py | 169 +++++++++++++++------
.../tests/unit/executors/test_workloads.py | 25 ++-
.../src/tests_common/test_utils/mock_executor.py | 3 +-
12 files changed, 282 insertions(+), 86 deletions(-)
diff --git a/airflow-core/docs/security/jwt_token_authentication.rst
b/airflow-core/docs/security/jwt_token_authentication.rst
index 7aa85bba9a3..e823c9f787a 100644
--- a/airflow-core/docs/security/jwt_token_authentication.rst
+++ b/airflow-core/docs/security/jwt_token_authentication.rst
@@ -201,16 +201,25 @@ Token structure (Execution API)
Token scopes (Execution API)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-The Execution API defines two token scopes:
+The Execution API defines two token scopes with different lifetimes:
**workload**
- A restricted scope accepted only on endpoints that explicitly opt in via
- ``Security(require_auth, scopes=["token:workload"])``. Used for endpoints
that
- manage task state transitions.
+ A token embedded in the workload JSON payload when the Scheduler
+ dispatches a task. The longer lifetime
+ allows tasks to remain valid while waiting in executor queues before
execution
+ begins. When a worker calls the ``/run`` endpoint with a ``workload``
token, the
+ server issues a fresh ``execution``-scoped token in the
``Refreshed-API-Token``
+ response header. Lifetime equals ``[scheduler] task_queued_timeout``
(default
+ 600 seconds) — the same timeout the scheduler uses to reap queue-starved
tasks —
+ so tuning ``task_queued_timeout`` also widens the window a task can wait in
a
+ backed-up queue before its workload token expires.
**execution**
- Accepted by all Execution API endpoints. This is the standard scope for
worker
- communication and allows access
+ A short-lived token (default 10 minutes) accepted by all Execution API
endpoints.
+ This is the standard scope for worker communication during task execution.
Issued
+ by the server when the worker transitions to running via the ``/run``
endpoint.
+ The ``JWTReissueMiddleware`` refreshes ``execution`` tokens transparently,
+ so the worker maintains access for the duration of the task.
Tokens without a ``scope`` claim default to ``"execution"`` for backwards
compatibility.
@@ -219,14 +228,19 @@ Token delivery to workers
The token flows through the execution stack as follows:
-1. **Scheduler** generates the token and embeds it in the workload JSON
payload that it passes to
- **Executor**.
+1. **Scheduler** generates a ``workload``-scoped token (lifetime equals
+ ``[scheduler] task_queued_timeout``, default 600 seconds) and embeds it in
the workload
+ JSON payload that it passes to **Executor**.
2. The workload JSON is passed to the worker process (via the
executor-specific mechanism:
Celery message, Kubernetes Pod spec, local subprocess arguments, etc.).
3. The worker's ``execute_workload()`` function reads the workload JSON and
extracts the token.
4. The ``supervise()`` function receives the token and creates an
``httpx.Client`` instance
with ``BearerAuth(token)`` for all Execution API HTTP requests.
-5. The token is included in the ``Authorization: Bearer <token>`` header of
every request.
+5. The worker calls the ``/run`` endpoint with the ``workload``-scoped token
to mark the task
+ as running. The server responds with a fresh ``execution``-scoped token in
the
+ ``Refreshed-API-Token`` header.
+6. The client's ``_update_auth()`` hook detects the header and transparently
updates
+ the ``BearerAuth`` instance to use the new ``execution`` token for all
subsequent requests.
Token validation (Execution API)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -251,7 +265,8 @@ Route-level enforcement is handled by ``require_auth``:
Token refresh (Execution API)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-The ``JWTReissueMiddleware`` automatically refreshes valid tokens that are
approaching expiry:
+The ``JWTReissueMiddleware`` automatically refreshes valid tokens that are
approaching
+expiry. The token must be valid at the start of the request for refresh to
occur:
1. After each response, the middleware checks the token's remaining validity.
2. If less than **20%** of the total validity remains (minimum 30 seconds),
the server
@@ -260,16 +275,20 @@ The ``JWTReissueMiddleware`` automatically refreshes
valid tokens that are appro
4. The client's ``_update_auth()`` hook detects this header and transparently
updates
the ``BearerAuth`` instance for subsequent requests.
-This mechanism ensures long-running tasks do not lose API access due to token
expiry,
-without requiring the worker to re-authenticate.
+The middleware only refreshes ``execution``-scoped tokens. ``workload``-scoped
tokens are
+sized to span the queued-timeout window and are explicitly skipped by the
middleware —
+they are designed to survive executor queue wait times without needing
refresh. This
+ensures long-running tasks do not lose API access without requiring the worker
to
+re-authenticate.
No token revocation (Execution API)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Execution API tokens are not subject to revocation. They are short-lived
(default 10 minutes)
-and automatically refreshed by the ``JWTReissueMiddleware``, so revocation is
not part of the
-Execution API security model. Once an Execution API token is issued to a
worker, it remains
-valid until it expires.
+Execution API tokens are not subject to revocation. ``execution``-scoped
tokens are short-lived
+(default 10 minutes) and automatically refreshed by the
``JWTReissueMiddleware``.
+``workload``-scoped tokens (tracking ``[scheduler] task_queued_timeout``) are
not refreshed —
+they expire naturally after their validity period. Revocation is not part of
the Execution API
+security model.
@@ -284,11 +303,12 @@ Default timings (Execution API)
- Default
* - ``[execution_api] jwt_expiration_time``
- 600 seconds (10 minutes)
+ * - Workload token lifetime (derived)
+ - ``[scheduler] task_queued_timeout`` (default 600 seconds)
* - ``[execution_api] jwt_audience``
- ``urn:airflow.apache.org:task``
* - Token refresh threshold
- - 20% of validity remaining (minimum 30 seconds, i.e., at ~120 seconds
before expiry
- with the default 600-second token lifetime)
+ - 20% of validity remaining (minimum 30 seconds)
Dag File Processor and Triggerer
@@ -386,7 +406,10 @@ All JWT-related configuration parameters:
- JWKS endpoint URL or local file path for token validation. Mutually
exclusive with ``jwt_secret``.
* - ``[execution_api] jwt_expiration_time``
- 600 (10 min)
- - Execution API token lifetime in seconds.
+ - Execution API ``execution``-scoped token lifetime in seconds.
+ * - ``[scheduler] task_queued_timeout``
+ - 600.0 (10 min)
+ - Queue-starvation timeout. Also sets the ``workload``-scoped token
lifetime to the same value.
* - ``[execution_api] jwt_audience``
- ``urn:airflow.apache.org:task``
- Audience claim for Execution API tokens.
diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py
b/airflow-core/src/airflow/api_fastapi/auth/tokens.py
index 3375853a29a..707d427101c 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py
@@ -447,15 +447,21 @@ class JWTGenerator:
assert self._secret_key
return self._secret_key
- def generate(self, extras: dict[str, Any] | None = None, headers:
dict[str, Any] | None = None) -> str:
+ def generate(
+ self,
+ extras: dict[str, Any] | None = None,
+ headers: dict[str, Any] | None = None,
+ valid_for: float | None = None,
+ ) -> str:
"""Generate a signed JWT for the subject."""
now = int(datetime.now(tz=timezone.utc).timestamp())
+ effective_valid_for = valid_for if valid_for is not None else
self.valid_for
claims = {
"jti": uuid.uuid4().hex,
"iss": self.issuer,
"aud": self.audience,
"nbf": now,
- "exp": int(now + self.valid_for),
+ "exp": int(now + effective_valid_for),
"iat": now,
}
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index 4f1f671cd83..8dbdf25f7e0 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -129,8 +129,6 @@ class CorrelationIdMiddleware(BaseHTTPMiddleware):
class JWTReissueMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
- from airflow.configuration import conf
-
response: Response = await call_next(request)
refreshed_token: str | None = None
@@ -142,9 +140,15 @@ class JWTReissueMiddleware(BaseHTTPMiddleware):
validator: JWTValidator = await services.aget(JWTValidator)
claims = await validator.avalidated_claims(token, {})
+ # Workload tokens are long-lived and meant to survive queue
+ # wait times so avoid refreshing them. If avalidated_claims
+ # raises for a workload token, the outer except handles it.
+ if claims.get("scope") == "workload":
+ return response
+
now = int(time.time())
- validity = conf.getint("execution_api",
"jwt_expiration_time")
- refresh_when_less_than = max(int(validity * 0.20), 30)
+ token_lifetime = int(claims.get("exp", 0)) -
int(claims.get("iat", 0))
+ refresh_when_less_than = max(int(token_lifetime * 0.20),
30)
valid_left = int(claims.get("exp", 0)) - now
if valid_left <= refresh_when_less_than:
generator: JWTGenerator = await
services.aget(JWTGenerator)
@@ -312,7 +316,6 @@ class InProcessExecutionAPI:
def app(self):
if not self._app:
from airflow.api_fastapi.common.dagbag import create_dag_bag
- from airflow.api_fastapi.execution_api.app import
create_task_execution_api_app
from airflow.api_fastapi.execution_api.datamodels.token import
TIClaims, TIToken
from airflow.api_fastapi.execution_api.routes.connections import
has_connection_access
from airflow.api_fastapi.execution_api.routes.variables import
has_variable_access
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 0b17c97530a..d4a1309093a 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -28,7 +28,7 @@ from uuid import UUID
import attrs
import structlog
from cadwyn import VersionedAPIRouter
-from fastapi import Body, HTTPException, Query, Security, status
+from fastapi import Body, HTTPException, Query, Response, Security, status
from opentelemetry import trace
from opentelemetry.trace import StatusCode
from opentelemetry.trace.propagation.tracecontext import
TraceContextTextMapPropagator
@@ -42,6 +42,7 @@ from structlog.contextvars import bind_contextvars
from airflow._shared.observability.traces import override_ids
from airflow._shared.timezones import timezone
+from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.common.dagbag import DagBagDep,
get_latest_version_of_dag
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.types import UtcDateTime
@@ -63,7 +64,9 @@ from
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TISuccessStatePayload,
TITerminalStatePayload,
)
-from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute,
require_auth
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.deps import DepContainer
+from airflow.api_fastapi.execution_api.security import CurrentTIToken,
ExecutionAPIRoute, require_auth
from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
from airflow.models.dag import DagModel
@@ -98,6 +101,7 @@ tracer = trace.get_tracer(__name__)
@ti_id_router.patch(
"/{task_instance_id}/run",
status_code=status.HTTP_200_OK,
+ dependencies=[Security(require_auth, scopes=["token:execution",
"token:workload"])],
responses={
status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
status.HTTP_409_CONFLICT: {"description": "The TI is already in the
requested state"},
@@ -108,8 +112,11 @@ tracer = trace.get_tracer(__name__)
def ti_run(
task_instance_id: UUID,
ti_run_payload: Annotated[TIEnterRunningPayload, Body()],
+ response: Response,
session: SessionDep,
dag_bag: DagBagDep,
+ services=DepContainer,
+ token: TIToken = CurrentTIToken,
) -> TIRunContext:
"""
Run a TaskInstance.
@@ -289,13 +296,20 @@ def ti_run(
context.next_method = ti.next_method
context.next_kwargs = ti.next_kwargs
context.start_date = ti.start_date
- return context
except SQLAlchemyError:
log.exception("Error marking Task Instance state as running")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Database error occurred"
)
+ # JWTReissueMiddleware also writes Refreshed-API-Token but skips workload
tokens, so we set it here for the workload→execution swap.
+ if token.claims.scope == "workload":
+ generator: JWTGenerator = services.get(JWTGenerator)
+ execution_token = generator.generate(extras={"sub":
str(task_instance_id), "scope": "execution"})
+ response.headers["Refreshed-API-Token"] = execution_token
+
+ return context
+
@ti_id_router.patch(
"/{task_instance_id}/state",
diff --git a/airflow-core/src/airflow/config_templates/config.yml
b/airflow-core/src/airflow/config_templates/config.yml
index 99eb360ab72..bbb137880ba 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -2576,6 +2576,10 @@ scheduler:
task_queued_timeout:
description: |
Amount of time a task can be in the queued state before being retried
or set to failed.
+
+ This value also sets the lifetime of the workload JWT token that is
sent with the task
+ to an executor queue, so a task waiting in the queue can still
authenticate to the
+ Execution API until its queue-starvation deadline.
version_added: 2.6.0
type: float
example: ~
diff --git a/airflow-core/src/airflow/executors/workloads/base.py
b/airflow-core/src/airflow/executors/workloads/base.py
index 97cf16ebaf6..1cbf71b9595 100644
--- a/airflow-core/src/airflow/executors/workloads/base.py
+++ b/airflow-core/src/airflow/executors/workloads/base.py
@@ -24,6 +24,8 @@ from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field
+from airflow.configuration import conf
+
if TYPE_CHECKING:
from airflow.api_fastapi.auth.tokens import JWTGenerator
@@ -74,7 +76,13 @@ class BaseWorkloadSchema(BaseModel):
@staticmethod
def generate_token(sub_id: str, generator: JWTGenerator | None = None) ->
str:
- return generator.generate({"sub": sub_id}) if generator else ""
+ if not generator:
+ return ""
+ valid_for = conf.getfloat("scheduler", "task_queued_timeout")
+ return generator.generate(
+ extras={"sub": sub_id, "scope": "workload"},
+ valid_for=valid_for,
+ )
class BaseDagBundleWorkload(BaseWorkloadSchema, ABC):
diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
index 8b76dff6217..4dfd186756b 100644
--- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
+++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
@@ -160,6 +160,34 @@ def test_secret_key_with_configured_kid():
assert header["kid"] == "my-custom-kid"
+def test_generate_with_custom_valid_for():
+ """generate() accepts a valid_for override."""
+ generator = JWTGenerator(secret_key="test-secret", audience="test",
valid_for=60)
+ token = generator.generate(extras={"sub": "user"}, valid_for=3600)
+ claims = jwt.decode(token, "test-secret", algorithms=["HS512"],
audience="test")
+ assert claims["exp"] - claims["iat"] == 3600
+
+
+def test_generate_workload_scope_via_extras():
+ """generate() with scope='workload' in extras produces a workload-scoped
token."""
+ generator = JWTGenerator(secret_key="test-secret", audience="test",
valid_for=60)
+
+ token = generator.generate(extras={"sub": "ti-123", "scope": "workload"},
valid_for=86400)
+ claims = jwt.decode(token, "test-secret", algorithms=["HS512"],
audience="test")
+ assert claims["sub"] == "ti-123"
+ assert claims["scope"] == "workload"
+ assert claims["exp"] - claims["iat"] == 86400
+
+
+def test_regular_token_has_no_scope():
+ """Regular tokens without scope in extras have no scope claim."""
+ generator = JWTGenerator(secret_key="test-secret", audience="test",
valid_for=60)
+
+ regular = generator.generate(extras={"sub": "user"})
+ regular_claims = jwt.decode(regular, "test-secret", algorithms=["HS512"],
audience="test")
+ assert "scope" not in regular_claims
+
+
@pytest.fixture
def jwt_generator(ed25519_private_key: Ed25519PrivateKey):
key = ed25519_private_key
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
index 20eed734701..1afd038c039 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
@@ -22,8 +22,16 @@ from fastapi.testclient import TestClient
from starlette.routing import Mount
from airflow.api_fastapi.app import cached_app
+from airflow.api_fastapi.execution_api.app import lifespan
from airflow.api_fastapi.execution_api.datamodels.token import TIClaims,
TIToken
-from airflow.api_fastapi.execution_api.security import _jwt_bearer
+from airflow.api_fastapi.execution_api.security import require_auth
+
+
[email protected](autouse=True)
+def _restore_lifespan_registry():
+ snapshot = dict(lifespan.registry._services)
+ yield
+ lifespan.registry._services = snapshot
def _get_execution_api_app(root_app: FastAPI) -> FastAPI:
@@ -45,16 +53,15 @@ def client(request: pytest.FixtureRequest):
app = cached_app(apps="execution")
exec_app = _get_execution_api_app(app)
- async def mock_jwt_bearer(request: Request):
+ async def mock_require_auth(request: Request) -> TIToken:
from uuid import UUID
ti_id = UUID(request.path_params.get("task_instance_id",
"00000000-0000-0000-0000-000000000000"))
- claims = TIClaims(scope="execution")
- return TIToken(id=ti_id, claims=claims)
+ return TIToken(id=ti_id, claims=TIClaims(scope="execution"))
- exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer
+ exec_app.dependency_overrides[require_auth] = mock_require_auth
with TestClient(app, headers={"Authorization": "Bearer fake"}) as client:
yield client
- exec_app.dependency_overrides.pop(_jwt_bearer, None)
+ exec_app.dependency_overrides.pop(require_auth, None)
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
index 85f7df46915..3a3e669f7e3 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
@@ -25,8 +25,6 @@ from fastapi import FastAPI
from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.execution_api.app import lifespan
-from tests_common.test_utils.config import conf_vars
-
@pytest.fixture
def exec_app(client):
@@ -53,6 +51,7 @@ def test_expiring_token_is_reissued(
auth = AsyncMock(spec=JWTValidator)
auth.avalidated_claims.return_value = {
"sub": "edb09971-4e0e-4221-ad3f-800852d38085",
+ "iat": moment,
"exp": moment + validity,
}
@@ -62,8 +61,7 @@ def test_expiring_token_is_reissued(
lifespan.registry.register_value(JWTValidator, auth)
# In order to test this we need any endpoint to hit. The easiest one to
use is variable get
- with conf_vars({("execution_api", "jwt_expiration_time"): str(validity)}):
- response = client.get("/execution/variables/key1",
headers={"Authorization": "Bearer dummy"})
+ response = client.get("/execution/variables/key1",
headers={"Authorization": "Bearer dummy"})
if expect_refreshed_token:
assert "Refreshed-API-Token" in response.headers
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 6066dba5dc1..d0f73fc1617 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -24,6 +24,7 @@ from uuid import UUID, uuid4
import pytest
import uuid6
+from fastapi import Request
from opentelemetry import trace as otel_trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
@@ -36,9 +37,11 @@ from sqlalchemy.orm import Session
from airflow._shared.observability.traces import OverrideableRandomIdGenerator
from airflow._shared.timezones import timezone
-from airflow.api_fastapi.auth.tokens import JWTValidator
+from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
from airflow.api_fastapi.execution_api.app import lifespan
+from airflow.api_fastapi.execution_api.datamodels.token import TIClaims,
TIToken
from airflow.api_fastapi.execution_api.routes.task_instances import
_emit_task_span
+from airflow.api_fastapi.execution_api.security import require_auth
from airflow.exceptions import AirflowSkipException
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent,
AssetModel
@@ -111,10 +114,8 @@ def _create_asset_aliases(session, num: int = 2) -> None:
@pytest.fixture
def _use_real_jwt_bearer(exec_app):
- """Remove the mock jwt_bearer override so the real JWTBearer.__call__
runs."""
- from airflow.api_fastapi.execution_api.security import _jwt_bearer
-
- exec_app.dependency_overrides.pop(_jwt_bearer, None)
+ """Remove the mock require_auth override so the real JWT validation runs
end-to-end."""
+ exec_app.dependency_overrides.pop(require_auth, None)
@pytest.mark.usefixtures("_use_real_jwt_bearer")
@@ -242,6 +243,8 @@ class TestTIRunState:
}
# upstream_map_indexes is now computed by Task SDK, not returned by
the server in HEAD version
assert "upstream_map_indexes" not in result
+ # execution-scoped tokens do not trigger a token swap
+ assert "Refreshed-API-Token" not in response.headers
# Refresh the Task Instance from the database so that we can check the
updated values
session.refresh(ti)
@@ -281,6 +284,54 @@ class TestTIRunState:
)
assert response.status_code == 409
+ def test_ti_run_returns_execution_token(
+ self, client, exec_app, session, create_task_instance, time_machine
+ ):
+ """PATCH /run with a workload token should swap to an execution-scoped
token."""
+ instant = timezone.parse("2024-10-31T12:00:00Z")
+ time_machine.move_to(instant, tick=False)
+
+ ti = create_task_instance(
+ task_id="test_exec_token",
+ state=State.QUEUED,
+ dagrun_state=DagRunState.RUNNING,
+ session=session,
+ start_date=instant,
+ dag_id=str(uuid4()),
+ )
+ session.commit()
+
+ mock_gen = mock.MagicMock(spec=JWTGenerator)
+ mock_gen.generate.return_value = "mock-execution-token"
+ lifespan.registry.register_value(JWTGenerator, mock_gen)
+
+ async def workload_token(request: Request) -> TIToken:
+ ti_id = UUID(request.path_params.get("task_instance_id",
"00000000-0000-0000-0000-000000000000"))
+ return TIToken(id=ti_id, claims=TIClaims(scope="workload"))
+
+ exec_app.dependency_overrides[require_auth] = workload_token
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/run",
+ json={
+ "state": "running",
+ "hostname": "test-host",
+ "unixname": "test-user",
+ "pid": 100,
+ "start_date": "2024-10-31T12:00:00Z",
+ },
+ )
+
+ exec_app.dependency_overrides.pop(require_auth, None)
+
+ assert response.status_code == 200
+ assert "Refreshed-API-Token" in response.headers
+ assert response.headers["Refreshed-API-Token"] ==
"mock-execution-token"
+ mock_gen.generate.assert_called_once()
+ extras = mock_gen.generate.call_args.kwargs["extras"]
+ assert extras["scope"] == "execution"
+ assert extras["sub"] == str(ti.id)
+
def test_dynamic_task_mapping_with_parse_time_value(self, client,
dag_maker):
"""Test that dynamic task mapping works correctly with parse-time
values."""
with dag_maker("test_dynamic_task_mapping_with_parse_time_value",
serialized=True):
@@ -3439,40 +3490,56 @@ class TestTIPatchRenderedMapIndex:
class TestTokenTypeValidation:
"""Test token scope enforcement (workload vs execution)."""
- def test_workload_scope_rejected_on_default_endpoints(self, client,
session, create_task_instance):
- """workload scoped tokens should be rejected on endpoints without
token:workload Security scope."""
+ def _register_scoped_validator(self, ti_id, scope):
+ """Register a JWTValidator mock returning claims with the given
scope."""
+ validator = mock.AsyncMock(spec=JWTValidator)
+ claims = {"sub": str(ti_id), "exp": 9999999999, "iat": 1000000000,
"nbf": 1000000000}
+ if scope is not None:
+ claims["scope"] = scope
+ validator.avalidated_claims.side_effect = lambda cred, validators:
claims
+ lifespan.registry.register_value(JWTValidator, validator)
+
+ def test_workload_scope_rejected_on_heartbeat_endpoint(self, client,
session, create_task_instance):
+ """Workload scoped tokens should be rejected on /heartbeat."""
ti = create_task_instance(task_id="test_ti_run_heartbeat",
state=State.RUNNING)
session.commit()
- validator = mock.AsyncMock(spec=JWTValidator)
- validator.avalidated_claims.side_effect = lambda cred, validators: {
- "sub": str(ti.id),
- "scope": "workload",
- "exp": 9999999999,
- "iat": 1000000000,
- "nbf": 1000000000,
- }
- lifespan.registry.register_value(JWTValidator, validator)
+ self._register_scoped_validator(ti.id, "workload")
payload = {"hostname": "test-host", "pid": 100}
resp = client.put(f"/execution/task-instances/{ti.id}/heartbeat",
json=payload)
assert resp.status_code == 403
assert "Token type 'workload' not allowed" in resp.json()["detail"]
+ def test_workload_scope_rejected_on_state_endpoint(self, client, session,
create_task_instance):
+ """Workload scoped tokens should be rejected on PATCH /state."""
+ ti = create_task_instance(task_id="test_workload_state",
state=State.RUNNING)
+ session.commit()
+
+ self._register_scoped_validator(ti.id, "workload")
+
+ payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"}
+ resp = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
+ assert resp.status_code == 403
+ assert "Token type 'workload' not allowed" in resp.json()["detail"]
+
+ def test_workload_scope_rejected_on_connections_endpoint(self, client,
session, create_task_instance):
+ """Workload scoped tokens should be rejected on GET /connections
(different router)."""
+ ti = create_task_instance(task_id="test_workload_conn",
state=State.RUNNING)
+ session.commit()
+
+ self._register_scoped_validator(ti.id, "workload")
+
+ resp = client.get("/execution/connections/test_conn")
+ assert resp.status_code == 403
+ assert "Token type 'workload' not allowed" in resp.json()["detail"]
+
def test_execution_scope_accepted_on_all_endpoints(self, client, session,
create_task_instance):
- """execution scoped tokens should be able to call all endpoints."""
+ """Execution scoped tokens should be accepted on all endpoints."""
ti = create_task_instance(task_id="test_ti_star", state=State.RUNNING)
session.commit()
- validator = mock.AsyncMock(spec=JWTValidator)
- validator.avalidated_claims.side_effect = lambda cred, validators: {
- "sub": str(ti.id),
- "scope": "execution",
- "exp": 9999999999,
- "iat": 1000000000,
- "nbf": 1000000000,
- }
- lifespan.registry.register_value(JWTValidator, validator)
+ self._register_scoped_validator(ti.id, "execution")
payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"}
resp = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
@@ -3483,15 +3550,7 @@ class TestTokenTypeValidation:
ti = create_task_instance(task_id="test_invalid_scope",
state=State.QUEUED)
session.commit()
- validator = mock.AsyncMock(spec=JWTValidator)
- validator.avalidated_claims.side_effect = lambda cred, validators: {
- "sub": str(ti.id),
- "scope": "bogus:scope",
- "exp": 9999999999,
- "iat": 1000000000,
- "nbf": 1000000000,
- }
- lifespan.registry.register_value(JWTValidator, validator)
+ self._register_scoped_validator(ti.id, "bogus:scope")
payload = {
"state": "running",
@@ -3505,19 +3564,43 @@ class TestTokenTypeValidation:
assert resp.status_code == 403
assert "Invalid auth token" in resp.json()["detail"]
+ def test_workload_scope_accepted_on_run_endpoint(
+ self, client, session, create_task_instance, time_machine
+ ):
+ """Workload scoped tokens should be accepted on the /run endpoint."""
+ instant = timezone.parse("2024-10-31T12:00:00Z")
+ time_machine.move_to(instant, tick=False)
+
+ ti = create_task_instance(
+ task_id="test_workload_run",
+ state=State.QUEUED,
+ dagrun_state=DagRunState.RUNNING,
+ session=session,
+ start_date=instant,
+ dag_id=str(uuid4()),
+ )
+ session.commit()
+
+ self._register_scoped_validator(ti.id, "workload")
+
+ resp = client.patch(
+ f"/execution/task-instances/{ti.id}/run",
+ json={
+ "state": "running",
+ "hostname": "test-host",
+ "unixname": "test-user",
+ "pid": 100,
+ "start_date": "2024-10-31T12:00:00Z",
+ },
+ )
+ assert resp.status_code == 200
+
def test_no_scope_defaults_to_execution(self, client, session,
create_task_instance):
"""Tokens without scope claim should default to 'execution'."""
ti = create_task_instance(task_id="test_no_scope", state=State.RUNNING)
session.commit()
- validator = mock.AsyncMock(spec=JWTValidator)
- validator.avalidated_claims.side_effect = lambda cred, validators: {
- "sub": str(ti.id),
- "exp": 9999999999,
- "iat": 1000000000,
- "nbf": 1000000000,
- }
- lifespan.registry.register_value(JWTValidator, validator)
+ self._register_scoped_validator(ti.id, None)
payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"}
resp = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
diff --git a/airflow-core/tests/unit/executors/test_workloads.py
b/airflow-core/tests/unit/executors/test_workloads.py
index 1a67ab96d40..2c3ffbf53ea 100644
--- a/airflow-core/tests/unit/executors/test_workloads.py
+++ b/airflow-core/tests/unit/executors/test_workloads.py
@@ -20,9 +20,12 @@ from __future__ import annotations
from pathlib import PurePosixPath
from uuid import uuid4
+import jwt
+
+from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.executors import workloads
-from airflow.executors.workloads import TaskInstance, TaskInstanceDTO
-from airflow.executors.workloads.base import BundleInfo
+from airflow.executors.workloads import TaskInstance, TaskInstanceDTO, base as
workloads_base
+from airflow.executors.workloads.base import BaseWorkloadSchema, BundleInfo
from airflow.executors.workloads.task import ExecuteTask
@@ -61,3 +64,21 @@ def test_token_excluded_from_workload_repr():
assert fake_token not in workload_repr, f"JWT token leaked into repr!
Found token in: {workload_repr}"
# But token should still be accessible as an attribute
assert workload.token == fake_token
+
+
+def test_generate_token_produces_workload_scope(monkeypatch):
+ """generate_token should create a JWT with scope 'workload' and
[scheduler] task_queued_timeout expiry."""
+ monkeypatch.setattr(workloads_base.conf, "getfloat", lambda section, key:
86400.0)
+
+ generator = JWTGenerator(secret_key="test-secret", audience="test",
valid_for=60)
+ token = BaseWorkloadSchema.generate_token("ti-123", generator)
+
+ claims = jwt.decode(token, "test-secret", algorithms=["HS512"],
audience="test")
+ assert claims["sub"] == "ti-123"
+ assert claims["scope"] == "workload"
+ assert claims["exp"] - claims["iat"] == 86400
+
+
+def test_generate_token_without_generator():
+ """generate_token should return empty string when no generator is
provided."""
+ assert BaseWorkloadSchema.generate_token("ti-123", None) == ""
diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py
b/devel-common/src/tests_common/test_utils/mock_executor.py
index 4e95ed3a4ee..c7a2f263152 100644
--- a/devel-common/src/tests_common/test_utils/mock_executor.py
+++ b/devel-common/src/tests_common/test_utils/mock_executor.py
@@ -22,6 +22,7 @@ from collections.abc import Sequence
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
+from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.executor_utils import ExecutorName
from airflow.models.taskinstance import TaskInstance
@@ -57,7 +58,7 @@ class MockExecutor(BaseExecutor):
self.mock_task_results = defaultdict(self.success)
# Mock JWT generator for token generation
- mock_jwt_generator = MagicMock()
+ mock_jwt_generator = MagicMock(spec=JWTGenerator)
mock_jwt_generator.generate.return_value = "mock-token"
self.jwt_generator = mock_jwt_generator