This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f7270c8a202 Extend OpenAPI schema with extra models for Task SDK (#44076)
f7270c8a202 is described below

commit f7270c8a2026b8da07590623560ba58e9da38d7f
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Nov 15 22:06:29 2024 +0000

    Extend OpenAPI schema with extra models for Task SDK (#44076)
    
    - Introduced `custom_openapi` to extend the OpenAPI schema with additional models.
    - Added a `TaskInstance` model for inclusion in the OpenAPI schema, specifically for the Task SDK.
    
    Reference: https://fastapi.tiangolo.com/how-to/extending-openapi/#modify-the-openapi-schema
---
 airflow/api_fastapi/execution_api/app.py           | 45 ++++++++++++++++++++--
 .../execution_api/datamodels/taskinstance.py       | 15 ++++++++
 .../src/airflow/sdk/api/datamodels/_generated.py   | 16 ++++++--
 .../src/airflow/sdk/api/datamodels/activities.py   |  2 +-
 task_sdk/src/airflow/sdk/execution_time/comms.py   |  3 +-
 .../src/airflow/sdk/execution_time/supervisor.py   |  4 +-
 .../src/airflow/sdk/execution_time/task_runner.py  |  3 +-
 task_sdk/tests/execution_time/test_supervisor.py   |  2 +-
 task_sdk/tests/execution_time/test_task_runner.py  |  2 +-
 .../api_fastapi/execution_api/test_app.py          | 21 +++++-----
 10 files changed, 88 insertions(+), 25 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py
index 1751b61bcd5..e019e8f14f3 100644
--- a/airflow/api_fastapi/execution_api/app.py
+++ b/airflow/api_fastapi/execution_api/app.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
+from fastapi.openapi.utils import get_openapi
 
 
 @asynccontextmanager
@@ -34,11 +35,49 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI:
     from airflow.api_fastapi.execution_api.routes import execution_api_router
 
     # TODO: Add versioning to the API
-    task_exec_api_app = FastAPI(
+    app = FastAPI(
         title="Airflow Task Execution API",
         description="The private Airflow Task Execution API.",
         lifespan=lifespan,
     )
 
-    task_exec_api_app.include_router(execution_api_router)
-    return task_exec_api_app
+    def custom_openapi() -> dict:
+        """
+        Customize the OpenAPI schema to include additional schemas not tied to 
specific endpoints.
+
+        This is particularly useful for client SDKs that require models for 
types
+        not directly exposed in any endpoint's request or response schema.
+
+        References:
+            - 
https://fastapi.tiangolo.com/how-to/extending-openapi/#modify-the-openapi-schema
+        """
+        if app.openapi_schema:
+            return app.openapi_schema
+        openapi_schema = get_openapi(
+            title=app.title,
+            description=app.description,
+            version=app.version,
+            routes=app.routes,
+        )
+
+        extra_schemas = get_extra_schemas()
+        for schema_name, schema in extra_schemas.items():
+            if schema_name not in openapi_schema["components"]["schemas"]:
+                openapi_schema["components"]["schemas"][schema_name] = schema
+
+        app.openapi_schema = openapi_schema
+        return app.openapi_schema
+
+    app.openapi = custom_openapi  # type: ignore[method-assign]
+
+    app.include_router(execution_api_router)
+    return app
+
+
+def get_extra_schemas() -> dict[str, dict]:
+    """Get all the extra schemas that are not part of the main FastAPI app."""
+    from airflow.api_fastapi.execution_api.datamodels import taskinstance
+
+    return {
+        "TaskInstance": taskinstance.TaskInstance.model_json_schema(),
+    }
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index db63dc3a8db..07066eb5a5c 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import uuid
 from typing import Annotated, Literal, Union
 
 from pydantic import BaseModel, ConfigDict, Discriminator, Tag, WithJsonSchema
@@ -97,3 +98,17 @@ class TIHeartbeatInfo(BaseModel):
 
     hostname: str
     pid: int
+
+
+# This model is not used in the API, but it is included in generated OpenAPI 
schema
+# for use in the client SDKs.
+class TaskInstance(BaseModel):
+    """Schema for TaskInstance model with minimal required fields needed for 
Runtime."""
+
+    id: uuid.UUID
+
+    task_id: str
+    dag_id: str
+    run_id: str
+    try_number: int
+    map_index: int | None = None
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index e921bee4bc2..c1d10f74d4a 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -24,6 +24,7 @@ from __future__ import annotations
 from datetime import datetime
 from enum import Enum
 from typing import Annotated, Any, Literal
+from uuid import UUID
 
 from pydantic import BaseModel, Field
 
@@ -45,10 +46,9 @@ class ConnectionResponse(BaseModel):
 
 class IntermediateTIState(str, Enum):
     """
-    States that a Task Instance can be in that indicate it is not yet in a 
terminal or running state
+    States that a Task Instance can be in that indicate it is not yet in a 
terminal or running state.
     """
 
-    REMOVED = "removed"
     SCHEDULED = "scheduled"
     QUEUED = "queued"
     RESTARTING = "restarting"
@@ -89,12 +89,13 @@ class TITargetStatePayload(BaseModel):
 
 class TerminalTIState(str, Enum):
     """
-    States that a Task Instance can be in that indicate it has reached a 
terminal state
+    States that a Task Instance can be in that indicate it has reached a 
terminal state.
     """
 
     SUCCESS = "success"
     FAILED = "failed"
     SKIPPED = "skipped"
+    REMOVED = "removed"
 
 
 class ValidationError(BaseModel):
@@ -121,6 +122,15 @@ class XComResponse(BaseModel):
     value: Annotated[Any, Field(title="Value")]
 
 
+class TaskInstance(BaseModel):
+    id: Annotated[UUID, Field(title="Id")]
+    task_id: Annotated[str, Field(title="Task Id")]
+    dag_id: Annotated[str, Field(title="Dag Id")]
+    run_id: Annotated[str, Field(title="Run Id")]
+    try_number: Annotated[int, Field(title="Try Number")]
+    map_index: Annotated[int | None, Field(title="Map Index")] = None
+
+
 class HTTPValidationError(BaseModel):
     detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = 
None
 
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/activities.py b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
index 04f2b389d5d..30bf41f6a28 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/activities.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
@@ -21,7 +21,7 @@ import os
 
 from pydantic import BaseModel
 
-from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance
 
 
 class ExecuteTaskActivity(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py
index a78fbb3e33b..07b260a417d 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -47,8 +47,7 @@ from typing import Annotated, Any, Literal, Union
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from airflow.sdk.api.datamodels._generated import TerminalTIState  # noqa: 
TCH001
-from airflow.sdk.api.datamodels.ti import TaskInstance  # noqa: TCH001
+from airflow.sdk.api.datamodels._generated import TaskInstance, 
TerminalTIState  # noqa: TCH001
 
 
 class StartupDetails(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index c05c6138f96..6ecd8ff5698 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -43,7 +43,7 @@ import structlog
 from pydantic import TypeAdapter
 
 from airflow.sdk.api.client import Client
-from airflow.sdk.api.datamodels._generated import TerminalTIState
+from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.execution_time.comms import (
     ConnectionResponse,
     GetConnection,
@@ -55,8 +55,6 @@ if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger
 
     from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
-    from airflow.sdk.api.datamodels.ti import TaskInstance
-
 
 __all__ = ["WatchedSubprocess", "supervise"]
 
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index c952207bca5..a6d7569382b 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -29,7 +29,8 @@ import structlog
 from pydantic import ConfigDict, TypeAdapter
 
 from airflow.sdk import BaseOperator
-from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, 
ToSupervisor, ToTask
+from airflow.sdk.api.datamodels._generated import TaskInstance
+from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor, 
ToTask
 
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger as Logger
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 7a712d1cc0a..428ade1c35a 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -29,7 +29,7 @@ import structlog
 import structlog.testing
 
 from airflow.sdk.api import client as sdk_client
-from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
 from airflow.utils import timezone as tz
 
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index c634ba1255f..40c112170c6 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -24,7 +24,7 @@ from socket import socketpair
 import pytest
 from uuid6 import uuid7
 
-from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance
 from airflow.sdk.execution_time.comms import StartupDetails
 from airflow.sdk.execution_time.task_runner import CommsDecoder, parse
 
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/ti.py b/tests/api_fastapi/execution_api/test_app.py
similarity index 59%
rename from task_sdk/src/airflow/sdk/api/datamodels/ti.py
rename to tests/api_fastapi/execution_api/test_app.py
index ce9e1e870ae..ccd8b4c8db9 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/ti.py
+++ b/tests/api_fastapi/execution_api/test_app.py
@@ -14,19 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 from __future__ import annotations
 
-import uuid
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import 
TaskInstance
+
+
+def test_custom_openapi_includes_extra_schemas(client):
+    """Test to ensure that extra schemas are correctly included in the OpenAPI 
schema."""
 
-from pydantic import BaseModel
+    response = client.get("/execution/openapi.json")
+    assert response.status_code == 200
 
+    openapi_schema = response.json()
 
-class TaskInstance(BaseModel):
-    id: uuid.UUID
+    assert "TaskInstance" in openapi_schema["components"]["schemas"]
+    schema = openapi_schema["components"]["schemas"]["TaskInstance"]
 
-    task_id: str
-    dag_id: str
-    run_id: str
-    try_number: int
-    map_index: int | None = None
+    assert schema == TaskInstance.model_json_schema()

Reply via email to