This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c2a67df7fb6 Introduce API versioning into the Execution API (#47951)
c2a67df7fb6 is described below

commit c2a67df7fb66a7defdb1ff07ef1dc894920e6caa
Author: Ash Berlin-Taylor <a...@apache.org>
AuthorDate: Thu Mar 20 11:59:49 2025 +0000

    Introduce API versioning into the Execution API (#47951)
    
    Right now we have exactly one version, but using Cadwyn gives us the ability
    to make changes to the Execution API almost "at will", and Cadwyn will
    transparently upgrade requests made against older versions. The tl;dr is:
    "apply migrations to your requests so that you only have to maintain the
    latest version".
    
    From its docs[^1]:
    
    > Cadwyn allows you to support a single version of your code while
    > auto-generating the schemas and routes for older versions. You keep API
    > versioning encapsulated in small and independent "version change" modules
    > while your business logic stays simple and knows nothing about versioning.
    
    This gives us the freedom to change the API while continuing to
    support "all" old versions. This is a key aspect of the proposal in AIP-72. We
    don't yet have any new versions -- those will come later.
    
    [^1]: https://docs.cadwyn.dev/
---
 .pre-commit-config.yaml                            |  2 +-
 airflow/api_fastapi/execution_api/app.py           | 74 ++++++++++++++--------
 .../execution_api/datamodels/variable.py           |  8 +--
 .../api_fastapi/execution_api/datamodels/xcom.py   |  8 +++
 .../api_fastapi/execution_api/routes/__init__.py   |  3 +-
 .../execution_api/routes/task_instances.py         |  4 +-
 airflow/api_fastapi/execution_api/routes/xcoms.py  | 17 +++--
 dev/datamodel_code_formatter.py                    | 31 +++++++++
 hatch_build.py                                     |  3 +-
 task-sdk/dev/generate_models.py                    | 27 ++++----
 task-sdk/src/airflow/sdk/api/client.py             |  8 ++-
 .../src/airflow/sdk/api/datamodels/_generated.py   | 29 ++-------
 tests/api_fastapi/execution_api/conftest.py        |  6 +-
 .../execution_api/routes/test_health.py            |  4 +-
 .../execution_api/routes/test_task_instances.py    | 61 ++++++++++--------
 .../execution_api/routes/test_variables.py         |  2 +-
 tests/api_fastapi/execution_api/test_app.py        |  2 +-
 17 files changed, 175 insertions(+), 114 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a7eff537f15..5dba2a99cc4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1353,7 +1353,7 @@ repos:
       - id: generate-tasksdk-datamodels
         name: Generate Datamodels for TaskSDK client
         language: python
-        entry: uv run --active --group codegen --project 
apache-airflow-task-sdk --directory task-sdk -s dev/generate_models.py
+        entry: uv run -p 3.12 --no-progress --active --group codegen --project 
apache-airflow-task-sdk --directory task-sdk -s dev/generate_models.py
         pass_filenames: false
         files: ^airflow/api_fastapi/execution_api/.*\.py$
         require_serial: true
diff --git a/airflow/api_fastapi/execution_api/app.py 
b/airflow/api_fastapi/execution_api/app.py
index 12ed2de27e2..3546a57eeb4 100644
--- a/airflow/api_fastapi/execution_api/app.py
+++ b/airflow/api_fastapi/execution_api/app.py
@@ -17,14 +17,20 @@
 
 from __future__ import annotations
 
+import json
 from contextlib import AsyncExitStack
 from functools import cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 import attrs
 import svcs
+from cadwyn import (
+    Cadwyn,
+    HeadVersion,
+    Version,
+    VersionBundle,
+)
 from fastapi import FastAPI, Request
-from fastapi.openapi.utils import get_openapi
 from fastapi.responses import JSONResponse
 
 from airflow.api_fastapi.auth.tokens import JWTValidator, 
get_sig_validation_args
@@ -36,6 +42,11 @@ import structlog
 
 logger = structlog.get_logger(logger_name=__name__)
 
+__all__ = [
+    "create_task_execution_api_app",
+    "lifespan",
+]
+
 
 def _jwt_validator():
     from airflow.configuration import conf
@@ -67,18 +78,18 @@ async def lifespan(app: FastAPI, registry: svcs.Registry):
     yield
 
 
-def create_task_execution_api_app() -> FastAPI:
-    """Create FastAPI app for task execution API."""
-    from airflow.api_fastapi.execution_api.routes import execution_api_router
+class CadwynWithOpenAPICustomization(Cadwyn):
+    # Workaround lack of customzation 
https://github.com/zmievsa/cadwyn/issues/255
+    async def openapi_jsons(self, req: Request) -> JSONResponse:
+        resp = await super().openapi_jsons(req)
+        open_apischema = json.loads(resp.body)  # type: ignore[arg-type]
+        open_apischema = self.customize_openapi(open_apischema)
 
-    # TODO: Add versioning to the API
-    app = FastAPI(
-        title="Airflow Task Execution API",
-        description="The private Airflow Task Execution API.",
-        lifespan=lifespan,
-    )
+        resp.body = resp.render(open_apischema)
 
-    def custom_openapi() -> dict:
+        return resp
+
+    def customize_openapi(self, openapi_schema: dict[str, Any]) -> dict[str, 
Any]:
         """
         Customize the OpenAPI schema to include additional schemas not tied to 
specific endpoints.
 
@@ -88,16 +99,6 @@ def create_task_execution_api_app() -> FastAPI:
         References:
             - 
https://fastapi.tiangolo.com/how-to/extending-openapi/#modify-the-openapi-schema
         """
-        if app.openapi_schema:
-            return app.openapi_schema
-        openapi_schema = get_openapi(
-            title=app.title,
-            description=app.description,
-            version=app.version,
-            routes=app.routes,
-            servers=app.servers,
-        )
-
         extra_schemas = get_extra_schemas()
         for schema_name, schema in extra_schemas.items():
             if schema_name not in openapi_schema["components"]["schemas"]:
@@ -111,12 +112,33 @@ def create_task_execution_api_app() -> FastAPI:
             ],
         }
 
-        app.openapi_schema = openapi_schema
-        return app.openapi_schema
+        for comp in openapi_schema["components"]["schemas"].values():
+            for prop in comp.get("properties", {}).values():
+                # {"type": "string", "const": "deferred"}
+                # to
+                # {"type": "string", "enum": ["deferred"]}
+                #
+                # this produces better results in the code generator
+                if prop.get("type") == "string" and (const := 
prop.pop("const", None)):
+                    prop["enum"] = [const]
 
-    app.openapi = custom_openapi  # type: ignore[method-assign]
+        return openapi_schema
+
+
+def create_task_execution_api_app() -> FastAPI:
+    """Create FastAPI app for task execution API."""
+    from airflow.api_fastapi.execution_api.routes import execution_api_router
+
+    # See https://docs.cadwyn.dev/concepts/version_changes/ for info about API 
versions
+    app = CadwynWithOpenAPICustomization(
+        title="Airflow Task Execution API",
+        description="The private Airflow Task Execution API.",
+        lifespan=lifespan,
+        api_version_parameter_name="Airflow-API-Version",
+        versions=VersionBundle(HeadVersion(), Version("2025-03-19")),
+    )
 
-    app.include_router(execution_api_router)
+    app.generate_and_include_versioned_routers(execution_api_router)
 
     # As we are mounted as a sub app, we don't get any logs for unhandled 
exceptions without this!
     @app.exception_handler(Exception)
diff --git a/airflow/api_fastapi/execution_api/datamodels/variable.py 
b/airflow/api_fastapi/execution_api/datamodels/variable.py
index 73361908a81..fd49a5eae46 100644
--- a/airflow/api_fastapi/execution_api/datamodels/variable.py
+++ b/airflow/api_fastapi/execution_api/datamodels/variable.py
@@ -19,10 +19,10 @@ from __future__ import annotations
 
 from pydantic import Field
 
-from airflow.api_fastapi.core_api.base import BaseModel, ConfigDict, 
StrictBaseModel
+from airflow.api_fastapi.core_api.base import StrictBaseModel
 
 
-class VariableResponse(BaseModel):
+class VariableResponse(StrictBaseModel):
     """Variable schema for responses with fields that are needed for 
Runtime."""
 
     key: str
@@ -32,7 +32,5 @@ class VariableResponse(BaseModel):
 class VariablePostBody(StrictBaseModel):
     """Request body schema for creating variables."""
 
-    model_config = ConfigDict(extra="forbid")
-
-    value: str | None = Field(serialization_alias="val")
+    value: str | None = Field(alias="val")
     description: str | None = Field(default=None)
diff --git a/airflow/api_fastapi/execution_api/datamodels/xcom.py 
b/airflow/api_fastapi/execution_api/datamodels/xcom.py
index de4a03e811c..ae7ddd26761 100644
--- a/airflow/api_fastapi/execution_api/datamodels/xcom.py
+++ b/airflow/api_fastapi/execution_api/datamodels/xcom.py
@@ -17,10 +17,18 @@
 
 from __future__ import annotations
 
+import sys
+from typing import Any
+
 from pydantic import JsonValue
 
 from airflow.api_fastapi.core_api.base import BaseModel
 
+if sys.version_info < (3, 12):
+    # zmievsa/cadwyn#262
+    # Setting this to "Any" doesn't have any impact on the API as it has to be 
parsed as valid JSON regardless
+    JsonValue = Any  # type: ignore [misc]
+
 
 class XComResponse(BaseModel):
     """XCom schema for responses with fields that are needed for Runtime."""
diff --git a/airflow/api_fastapi/execution_api/routes/__init__.py 
b/airflow/api_fastapi/execution_api/routes/__init__.py
index 0c21b9c0535..6610d5552f2 100644
--- a/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow/api_fastapi/execution_api/routes/__init__.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from cadwyn import VersionedAPIRouter
 from fastapi import APIRouter
 
 from airflow.api_fastapi.execution_api.deps import JWTBearerDep
@@ -34,7 +35,7 @@ execution_api_router = APIRouter()
 execution_api_router.include_router(health.router, prefix="/health", 
tags=["Health"])
 
 # _Every_ single endpoint under here must be authenticated. Some do further 
checks
-authenticated_router = APIRouter(dependencies=[JWTBearerDep])  # type: 
ignore[list-item]
+authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep])  # 
type: ignore[list-item]
 
 authenticated_router.include_router(assets.router, prefix="/assets", 
tags=["Assets"])
 authenticated_router.include_router(asset_events.router, 
prefix="/asset-events", tags=["Asset Events"])
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 98b24d28d3b..6b031e36394 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -22,6 +22,7 @@ import logging
 from typing import Annotated
 from uuid import UUID
 
+from cadwyn import VersionedAPIRouter
 from fastapi import Body, Depends, HTTPException, status
 from pydantic import JsonValue
 from sqlalchemy import func, tuple_, update
@@ -29,7 +30,6 @@ from sqlalchemy.exc import NoResultFound, SQLAlchemyError
 from sqlalchemy.sql import select
 
 from airflow.api_fastapi.common.db.common import SessionDep
-from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
     PrevSuccessfulDagRunResponse,
     TIDeferredStatePayload,
@@ -52,7 +52,7 @@ from airflow.models.xcom import XComModel
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
 
-router = AirflowRouter(
+router = VersionedAPIRouter(
     dependencies=[
         # This checks that the UUID in the url matches the one in the token 
for us.
         Depends(JWTBearer(path_param_name="task_instance_id")),
diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py 
b/airflow/api_fastapi/execution_api/routes/xcoms.py
index 3612f3615ab..e75d111c0fe 100644
--- a/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -18,7 +18,8 @@
 from __future__ import annotations
 
 import logging
-from typing import Annotated
+import sys
+from typing import Annotated, Any
 
 from fastapi import Body, Depends, HTTPException, Path, Query, Request, 
Response, status
 from pydantic import JsonValue
@@ -60,6 +61,7 @@ router = AirflowRouter(
     responses={
         status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
         status.HTTP_403_FORBIDDEN: {"description": "Task does not have access 
to the XCom"},
+        status.HTTP_404_NOT_FOUND: {"description": "XCom not found"},
     },
     dependencies=[Depends(has_xcom_access)],
 )
@@ -89,7 +91,6 @@ async def xcom_query(
 @router.head(
     "/{dag_id}/{run_id}/{task_id}/{key}",
     responses={
-        status.HTTP_404_NOT_FOUND: {"description": "XCom not found"},
         status.HTTP_200_OK: {
             "description": "Metadata about the number of matching XCom values",
             "headers": {
@@ -100,7 +101,7 @@ async def xcom_query(
             },
         },
     },
-    description="Return the count of the number of XCom values found via the 
Content-Range response header",
+    description="Returns the count of mapped XCom values found in the 
`Content-Range` response header",
 )
 def head_xcom(
     response: Response,
@@ -123,7 +124,6 @@ def head_xcom(
 
 @router.get(
     "/{dag_id}/{run_id}/{task_id}/{key}",
-    responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
     description="Get a single XCom Value",
 )
 def get_xcom(
@@ -156,14 +156,17 @@ def get_xcom(
     return XComResponse(key=key, value=result.value)
 
 
+if sys.version_info < (3, 12):
+    # zmievsa/cadwyn#262
+    # Setting this to "Any" doesn't have any impact on the API as it has to be 
parsed as valid JSON regardless
+    JsonValue = Any  # type: ignore [misc]
+
+
 # TODO: once we have JWT tokens, then remove dag_id/run_id/task_id from the 
URL and just use the info in
 # the token
 @router.post(
     "/{dag_id}/{run_id}/{task_id}/{key}",
     status_code=status.HTTP_201_CREATED,
-    responses={
-        status.HTTP_400_BAD_REQUEST: {"description": "Invalid request body"},
-    },
 )
 def set_xcom(
     dag_id: str,
diff --git a/dev/datamodel_code_formatter.py b/dev/datamodel_code_formatter.py
index aa13d8b09a4..1f726073766 100644
--- a/dev/datamodel_code_formatter.py
+++ b/dev/datamodel_code_formatter.py
@@ -22,6 +22,7 @@ from pathlib import Path
 
 import libcst as cst
 from datamodel_code_generator.format import CustomCodeFormatter
+from libcst.helpers import parse_template_statement
 
 
 def license_text() -> str:
@@ -53,6 +54,34 @@ class CodeFormatter(CustomCodeFormatter):
                     return cst.RemoveFromParent()
                 return super().leave_ClassDef(original_node, updated_node)
 
+        class VersionConstInjtector(cst.CSTTransformer):
+            handled = False
+
+            def __init__(self, api_version: str) -> None:
+                self.api_version = api_version
+                super().__init__()
+
+            def leave_ImportFrom(
+                self, original_node: cst.ImportFrom, updated_node: 
cst.ImportFrom
+            ) -> cst.BaseSmallStatement | 
cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
+                # Ensure we have `from typing import Final`
+                if original_node.module and original_node.module.value == 
"typing":
+                    new_names = updated_node.names + 
(cst.ImportAlias(name=cst.Name("Final")),)  # type: ignore[operator]
+                    return updated_node.with_changes(names=new_names)
+                return super().leave_ImportFrom(original_node, updated_node)
+
+            def leave_ClassDef(self, original_node: cst.ClassDef, 
updated_node: cst.ClassDef):
+                if self.handled:
+                    return super().leave_ClassDef(original_node, updated_node)
+
+                self.handled = True
+
+                const = parse_template_statement(
+                    "API_VERSION: Final[str] = {api_version}",
+                    api_version=cst.SimpleString(f'"{self.api_version}"'),
+                )
+                return cst.FlattenSentinel([const, updated_node])
+
         # Remove Task class that represent a tuple of (task_id, map_index)
         # for `TISkippedDownstreamTasksStatePayload`
         class ModifyTasksAnnotation(cst.CSTTransformer):
@@ -102,6 +131,8 @@ class CodeFormatter(CustomCodeFormatter):
 
         source_tree = cst.parse_module(code)
         modified_tree = source_tree.visit(JsonValueNodeRemover())
+        if api_version := self.formatter_kwargs.get("api_version"):
+            modified_tree = 
modified_tree.visit(VersionConstInjtector(api_version))
         modified_tree = modified_tree.visit(ModifyTasksAnnotation())
         code = modified_tree.code
 
diff --git a/hatch_build.py b/hatch_build.py
index c76e0b38961..51a172b335b 100644
--- a/hatch_build.py
+++ b/hatch_build.py
@@ -192,6 +192,7 @@ DEPENDENCIES = [
     # Blinker use for signals in Flask, this is an optional dependency in 
Flask 2.2 and lower.
     # In Flask 2.3 it becomes a mandatory dependency, and flask signals are 
always available.
     "blinker>=1.6.2",
+    "cadwyn>=5.1.2",
     "colorlog>=6.8.2",
     "configupdater>=3.1.1",
     "cron-descriptor>=1.2.24",
@@ -204,7 +205,7 @@ DEPENDENCIES = [
     'eval-type-backport>=0.2.0;python_version<"3.10"',
     # 0.115.10 fastapi was a bad release that broke our API's and static 
checks.
     # Related fastapi issue here: 
https://github.com/fastapi/fastapi/discussions/13431
-    "fastapi[standard]>=0.112.2,!=0.115.10",
+    "fastapi[standard]>=0.112.4,!=0.115.10",
     "flask-caching>=2.0.0",
     # Flask-Session 0.6 add new arguments into the SqlAlchemySessionInterface 
constructor as well as
     # all parameters now are mandatory which make 
AirflowDatabaseSessionInterface incompatible with this version.
diff --git a/task-sdk/dev/generate_models.py b/task-sdk/dev/generate_models.py
index 60a17333de9..2c96e48d80f 100644
--- a/task-sdk/dev/generate_models.py
+++ b/task-sdk/dev/generate_models.py
@@ -16,12 +16,11 @@
 # under the License.
 from __future__ import annotations
 
-import json
 import os
 import sys
 from pathlib import Path
-from typing import TYPE_CHECKING
 
+import httpx
 from datamodel_code_generator import (
     DataModelType,
     DatetimeClassType,
@@ -43,13 +42,10 @@ from common_precommit_utils import (
 
 sys.path.insert(0, str(AIRFLOW_SOURCES_ROOT_PATH))  # make sure setup is 
imported from Airflow
 
-from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
+from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
 
 task_sdk_root = Path(__file__).parents[1]
 
-if TYPE_CHECKING:
-    from fastapi import FastAPI
-
 
 def load_config():
     try:
@@ -80,20 +76,21 @@ def load_config():
     return cfg
 
 
-def generate_file(app: FastAPI):
-    # The persisted openapi spec will list all endpoints (public and ui), this
-    # is used for code generation.
-    for route in app.routes:
-        if getattr(route, "name") == "webapp":
-            continue
-        route.__setattr__("include_in_schema", True)
+def generate_file():
+    app = InProcessExecutionAPI()
+
+    latest_version = app.app.versions.version_values[-1]
+    client = httpx.Client(transport=app.transport)
+    openapi_schema = (
+        
client.get(f"http://localhost/openapi.json?version={latest_version}";).raise_for_status().text
+    )
 
     os.chdir(task_sdk_root)
 
-    openapi_schema = json.dumps(app.openapi())
     args = load_config()
     args["input_filename"] = args.pop("url")
+    args["custom_formatters_kwargs"] = {"api_version": latest_version}
     generate_models(openapi_schema, **args)
 
 
-generate_file(create_task_execution_api_app())
+generate_file()
diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index 0b3c4971fe5..9108cbc08f8 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -35,6 +35,7 @@ from uuid6 import uuid7
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import 
TIRuntimeCheckPayload
 from airflow.sdk import __version__
 from airflow.sdk.api.datamodels._generated import (
+    API_VERSION,
     AssetEventsResponse,
     AssetResponse,
     ConnectionResponse,
@@ -256,7 +257,7 @@ class VariableOperations:
 
     def set(self, key: str, value: str | None, description: str | None = None):
         """Set an Airflow Variable via the API server."""
-        body = VariablePostBody(value=value, description=description)
+        body = VariablePostBody(val=value, description=description)
         self.client.put(f"variables/{key}", content=body.model_dump_json())
         # Any error from the server will anyway be propagated down to the 
supervisor,
         # so we choose to send a generic response to the supervisor over the 
server response to
@@ -513,7 +514,10 @@ class Client(httpx.Client):
         pyver = f"{'.'.join(map(str, sys.version_info[:3]))}"
         super().__init__(
             auth=auth,
-            headers={"user-agent": f"apache-airflow-task-sdk/{__version__} 
(Python/{pyver})"},
+            headers={
+                "user-agent": f"apache-airflow-task-sdk/{__version__} 
(Python/{pyver})",
+                "airflow-api-version": API_VERSION,
+            },
             event_hooks={"response": [raise_on_4xx_5xx], "request": 
[add_correlation_id]},
             **kwargs,
         )
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 8725cdc04cf..9b99a3d57c8 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -22,11 +22,13 @@ from __future__ import annotations
 
 from datetime import datetime, timedelta
 from enum import Enum
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Final, Literal
 from uuid import UUID
 
 from pydantic import BaseModel, ConfigDict, Field, JsonValue
 
+API_VERSION: Final[str] = "2025-03-19"
+
 
 class AssetProfile(BaseModel):
     """
@@ -95,14 +97,6 @@ class DagRunAssetReference(BaseModel):
 
 
 class DagRunState(str, Enum):
-    """
-    All possible states that a DagRun can be in.
-
-    These are "shared" with TaskInstanceState in some parts of the code,
-    so please ensure that their values always match the ones with the
-    same name in TaskInstanceState.
-    """
-
     QUEUED = "queued"
     RUNNING = "running"
     SUCCESS = "success"
@@ -118,10 +112,6 @@ class DagRunStateResponse(BaseModel):
 
 
 class DagRunType(str, Enum):
-    """
-    Class with DagRun types.
-    """
-
     BACKFILL = "backfill"
     SCHEDULED = "scheduled"
     MANUAL = "manual"
@@ -129,10 +119,6 @@ class DagRunType(str, Enum):
 
 
 class IntermediateTIState(str, Enum):
-    """
-    States that a Task Instance can be in that indicate it is not yet in a 
terminal or running state.
-    """
-
     SCHEDULED = "scheduled"
     QUEUED = "queued"
     RESTARTING = "restarting"
@@ -258,10 +244,6 @@ class TITargetStatePayload(BaseModel):
 
 
 class TerminalStateNonSuccess(str, Enum):
-    """
-    TaskInstance states that can be reported without extra information.
-    """
-
     FAILED = "failed"
     SKIPPED = "skipped"
     REMOVED = "removed"
@@ -295,7 +277,7 @@ class VariablePostBody(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    value: Annotated[str | None, Field(title="Value")] = None
+    val: Annotated[str | None, Field(title="Val")] = None
     description: Annotated[str | None, Field(title="Description")] = None
 
 
@@ -304,6 +286,9 @@ class VariableResponse(BaseModel):
     Variable schema for responses with fields that are needed for Runtime.
     """
 
+    model_config = ConfigDict(
+        extra="forbid",
+    )
     key: Annotated[str, Field(title="Key")]
     value: Annotated[str | None, Field(title="Value")] = None
 
diff --git a/tests/api_fastapi/execution_api/conftest.py 
b/tests/api_fastapi/execution_api/conftest.py
index 7ff9e31c26d..14ea02263f9 100644
--- a/tests/api_fastapi/execution_api/conftest.py
+++ b/tests/api_fastapi/execution_api/conftest.py
@@ -29,7 +29,11 @@ from airflow.api_fastapi.execution_api.app import lifespan
 @pytest.fixture
 def client(request: pytest.FixtureRequest):
     app = cached_app(apps="execution")
-    with TestClient(app, headers={"Authorization": "Bearer fake"}) as client:
+
+    # By specifying a "far-future" date, this will make the tests always run 
against the latest version
+    with TestClient(
+        app, headers={"Authorization": "Bearer fake", "Airflow-API-Version": 
"9999-12-31"}
+    ) as client:
         auth = AsyncMock(spec=JWTValidator)
         auth.avalidated_claims.return_value = {"sub": 
"edb09971-4e0e-4221-ad3f-800852d38085"}
 
diff --git a/tests/api_fastapi/execution_api/routes/test_health.py 
b/tests/api_fastapi/execution_api/routes/test_health.py
index 04ada69e538..dc13f9880a8 100644
--- a/tests/api_fastapi/execution_api/routes/test_health.py
+++ b/tests/api_fastapi/execution_api/routes/test_health.py
@@ -21,8 +21,8 @@ import pytest
 pytestmark = pytest.mark.db_test
 
 
-def test_health(test_client):
-    response = test_client.get("/execution/health")
+def test_health(client):
+    response = client.get("/execution/health")
 
     assert response.status_code == 200
     assert response.json() == {"status": "healthy"}
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py 
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 500c730a783..6cd74f856ab 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -23,15 +23,11 @@ from unittest import mock
 
 import pytest
 import uuid6
-from fastapi import FastAPI
-from fastapi.routing import Mount
 from sqlalchemy import select, update
 from sqlalchemy.exc import SQLAlchemyError
 
-from airflow.api_fastapi.app import purge_cached_app
 from airflow.api_fastapi.auth.tokens import JWTValidator
 from airflow.api_fastapi.execution_api.app import lifespan
-from airflow.api_fastapi.execution_api.routes.task_instances import router as 
task_instance_router
 from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
 from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, 
AssetModel
 from airflow.models.taskinstance import TaskInstance
@@ -63,43 +59,54 @@ def _create_asset_aliases(session, num: int = 2) -> None:
 
 
 @pytest.fixture
-def add_foo_test_route(client):
-    @task_instance_router.get("/{task_instance_id}/foo")
-    def foo(task_instance_id: str):
-        return {"hi": task_instance_id}
+def client_with_extra_route(): ...
 
-    app: FastAPI = client.app
 
-    last_route = app.routes[-1]
-    assert isinstance(last_route, Mount)
-    assert isinstance(last_route.app, FastAPI)
-    # Re-add it, so it gets the new route we've added
-    last_route.app.include_router(task_instance_router, 
prefix="/task-instances")
+def test_id_matches_sub_claim(client, session, create_task_instance):
+    # Test that this is validated at the router level, so we don't have to 
test it in each component
+    # We validate it is set correctly, and test it once
 
-    yield
+    ti = create_task_instance(
+        task_id="test_ti_run_state_conflict_if_not_queued",
+        state="queued",
+    )
+    session.commit()
 
-    purge_cached_app()
+    validator = mock.AsyncMock(spec=JWTValidator)
+    claims = {"sub": ti.id}
 
+    def side_effect(cred, validators):
+        if not validators:
+            return claims
+        if validators["sub"]["value"] != ti.id:
+            raise RuntimeError("Fake auth denied")
+        return claims
 
-@pytest.mark.usefixtures("add_foo_test_route")
-def test_id_matches_sub_claim(client):
-    # Test that this is validated at the router level, so we don't have to 
test it in each component
-    validator = mock.AsyncMock(spec=JWTValidator)
-    claims = {"sub": "edb09971-4e0e-4221-ad3f-800852d38085"}
-    validator.avalidated_claims.side_effect = [claims, RuntimeError("Fail for 
test")]
+    # validator.avalidated_claims.side_effect = [{}, RuntimeError("fail for 
tests"), claims, claims]
+    validator.avalidated_claims.side_effect = side_effect
 
     lifespan.registry.register_value(JWTValidator, validator)
 
-    resp = client.get(f"/execution/task-instances/{claims['sub']}/foo")
-    assert resp.status_code == 200
-
-    validator.avalidated_claims.assert_awaited_once()
+    payload = {
+        "state": "running",
+        "hostname": "random-hostname",
+        "unixname": "random-unixname",
+        "pid": 100,
+        "start_date": "2024-10-31T12:00:00Z",
+    }
 
-    resp = 
client.get("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/foo")
+    resp = 
client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run",
 json=payload)
     assert resp.status_code == 403
     validator.avalidated_claims.assert_called_with(
         mock.ANY, {"sub": {"essential": True, "value": 
"9c230b40-da03-451d-8bd7-be30471be383"}}
     )
+    validator.avalidated_claims.reset_mock()
+
+    resp = client.patch(f"/execution/task-instances/{ti.id}/run", json=payload)
+
+    assert resp.status_code == 200, resp.json()
+
+    validator.avalidated_claims.assert_awaited()
 
 
 class TestTIRunState:
diff --git a/tests/api_fastapi/execution_api/routes/test_variables.py 
b/tests/api_fastapi/execution_api/routes/test_variables.py
index 01d5a77f7be..abc0ec4d5c0 100644
--- a/tests/api_fastapi/execution_api/routes/test_variables.py
+++ b/tests/api_fastapi/execution_api/routes/test_variables.py
@@ -129,7 +129,7 @@ class TestPutVariable:
             f"/execution/variables/{key}",
             json=payload,
         )
-        assert response.status_code == 201
+        assert response.status_code == 201, response.json()
         assert response.json()["message"] == "Variable successfully set"
 
         var_from_db = session.query(Variable).where(Variable.key == 
"var_create").first()
diff --git a/tests/api_fastapi/execution_api/test_app.py 
b/tests/api_fastapi/execution_api/test_app.py
index b2129382f6a..6eeeec393a8 100644
--- a/tests/api_fastapi/execution_api/test_app.py
+++ b/tests/api_fastapi/execution_api/test_app.py
@@ -26,7 +26,7 @@ pytestmark = pytest.mark.db_test
 def test_custom_openapi_includes_extra_schemas(client):
     """Test to ensure that extra schemas are correctly included in the OpenAPI 
schema."""
 
-    response = client.get("/execution/openapi.json")
+    response = client.get("/execution/openapi.json?version=2025-03-19")
     assert response.status_code == 200
 
     openapi_schema = response.json()

Reply via email to