This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f7270c8a202 Extend OpenAPI schema with extra models for Task SDK (#44076)
f7270c8a202 is described below
commit f7270c8a2026b8da07590623560ba58e9da38d7f
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Nov 15 22:06:29 2024 +0000
Extend OpenAPI schema with extra models for Task SDK (#44076)
- Introduced `custom_openapi` to extend the OpenAPI schema with additional models.
- Added a `TaskInstance` model for inclusion in the OpenAPI schema, specifically for the Task SDK.
Reference:
https://fastapi.tiangolo.com/how-to/extending-openapi/#modify-the-openapi-schema
---
airflow/api_fastapi/execution_api/app.py | 45 ++++++++++++++++++++--
.../execution_api/datamodels/taskinstance.py | 15 ++++++++
.../src/airflow/sdk/api/datamodels/_generated.py | 16 ++++++--
.../src/airflow/sdk/api/datamodels/activities.py | 2 +-
task_sdk/src/airflow/sdk/execution_time/comms.py | 3 +-
.../src/airflow/sdk/execution_time/supervisor.py | 4 +-
.../src/airflow/sdk/execution_time/task_runner.py | 3 +-
task_sdk/tests/execution_time/test_supervisor.py | 2 +-
task_sdk/tests/execution_time/test_task_runner.py | 2 +-
.../api_fastapi/execution_api/test_app.py | 21 +++++-----
10 files changed, 88 insertions(+), 25 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py
index 1751b61bcd5..e019e8f14f3 100644
--- a/airflow/api_fastapi/execution_api/app.py
+++ b/airflow/api_fastapi/execution_api/app.py
@@ -20,6 +20,7 @@ from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import FastAPI
+from fastapi.openapi.utils import get_openapi
@asynccontextmanager
@@ -34,11 +35,49 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI:
from airflow.api_fastapi.execution_api.routes import execution_api_router
# TODO: Add versioning to the API
- task_exec_api_app = FastAPI(
+ app = FastAPI(
title="Airflow Task Execution API",
description="The private Airflow Task Execution API.",
lifespan=lifespan,
)
- task_exec_api_app.include_router(execution_api_router)
- return task_exec_api_app
+ def custom_openapi() -> dict:
+ """
+ Customize the OpenAPI schema to include additional schemas not tied to specific endpoints.
+
+ This is particularly useful for client SDKs that require models for types
+ not directly exposed in any endpoint's request or response schema.
+
+ References:
+ - https://fastapi.tiangolo.com/how-to/extending-openapi/#modify-the-openapi-schema
+ """
+ if app.openapi_schema:
+ return app.openapi_schema
+ openapi_schema = get_openapi(
+ title=app.title,
+ description=app.description,
+ version=app.version,
+ routes=app.routes,
+ )
+
+ extra_schemas = get_extra_schemas()
+ for schema_name, schema in extra_schemas.items():
+ if schema_name not in openapi_schema["components"]["schemas"]:
+ openapi_schema["components"]["schemas"][schema_name] = schema
+
+ app.openapi_schema = openapi_schema
+ return app.openapi_schema
+
+ app.openapi = custom_openapi # type: ignore[method-assign]
+
+ app.include_router(execution_api_router)
+ return app
+
+
+def get_extra_schemas() -> dict[str, dict]:
+ """Get all the extra schemas that are not part of the main FastAPI app."""
+ from airflow.api_fastapi.execution_api.datamodels import taskinstance
+
+ return {
+ "TaskInstance": taskinstance.TaskInstance.model_json_schema(),
+ }
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index db63dc3a8db..07066eb5a5c 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -17,6 +17,7 @@
from __future__ import annotations
+import uuid
from typing import Annotated, Literal, Union
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, WithJsonSchema
@@ -97,3 +98,17 @@ class TIHeartbeatInfo(BaseModel):
hostname: str
pid: int
+
+
+# This model is not used in the API, but it is included in generated OpenAPI schema
+# for use in the client SDKs.
+class TaskInstance(BaseModel):
+ """Schema for TaskInstance model with minimal required fields needed for Runtime."""
+
+ id: uuid.UUID
+
+ task_id: str
+ dag_id: str
+ run_id: str
+ try_number: int
+ map_index: int | None = None
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index e921bee4bc2..c1d10f74d4a 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -24,6 +24,7 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal
+from uuid import UUID
from pydantic import BaseModel, Field
@@ -45,10 +46,9 @@ class ConnectionResponse(BaseModel):
class IntermediateTIState(str, Enum):
"""
- States that a Task Instance can be in that indicate it is not yet in a terminal or running state
+ States that a Task Instance can be in that indicate it is not yet in a terminal or running state.
"""
- REMOVED = "removed"
SCHEDULED = "scheduled"
QUEUED = "queued"
RESTARTING = "restarting"
@@ -89,12 +89,13 @@ class TITargetStatePayload(BaseModel):
class TerminalTIState(str, Enum):
"""
- States that a Task Instance can be in that indicate it has reached a terminal state
+ States that a Task Instance can be in that indicate it has reached a terminal state.
"""
SUCCESS = "success"
FAILED = "failed"
SKIPPED = "skipped"
+ REMOVED = "removed"
class ValidationError(BaseModel):
@@ -121,6 +122,15 @@ class XComResponse(BaseModel):
value: Annotated[Any, Field(title="Value")]
+class TaskInstance(BaseModel):
+ id: Annotated[UUID, Field(title="Id")]
+ task_id: Annotated[str, Field(title="Task Id")]
+ dag_id: Annotated[str, Field(title="Dag Id")]
+ run_id: Annotated[str, Field(title="Run Id")]
+ try_number: Annotated[int, Field(title="Try Number")]
+ map_index: Annotated[int | None, Field(title="Map Index")] = None
+
+
class HTTPValidationError(BaseModel):
detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/activities.py b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
index 04f2b389d5d..30bf41f6a28 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/activities.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
@@ -21,7 +21,7 @@ import os
from pydantic import BaseModel
-from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance
class ExecuteTaskActivity(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py
index a78fbb3e33b..07b260a417d 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -47,8 +47,7 @@ from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, ConfigDict, Field
-from airflow.sdk.api.datamodels._generated import TerminalTIState # noqa: TCH001
-from airflow.sdk.api.datamodels.ti import TaskInstance # noqa: TCH001
+from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState # noqa: TCH001
class StartupDetails(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index c05c6138f96..6ecd8ff5698 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -43,7 +43,7 @@ import structlog
from pydantic import TypeAdapter
from airflow.sdk.api.client import Client
-from airflow.sdk.api.datamodels._generated import TerminalTIState
+from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
ConnectionResponse,
GetConnection,
@@ -55,8 +55,6 @@ if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger
from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
- from airflow.sdk.api.datamodels.ti import TaskInstance
-
__all__ = ["WatchedSubprocess", "supervise"]
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index c952207bca5..a6d7569382b 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -29,7 +29,8 @@ import structlog
from pydantic import ConfigDict, TypeAdapter
from airflow.sdk import BaseOperator
-from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, ToSupervisor, ToTask
+from airflow.sdk.api.datamodels._generated import TaskInstance
+from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor, ToTask
if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 7a712d1cc0a..428ade1c35a 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -29,7 +29,7 @@ import structlog
import structlog.testing
from airflow.sdk.api import client as sdk_client
-from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.utils import timezone as tz
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index c634ba1255f..40c112170c6 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -24,7 +24,7 @@ from socket import socketpair
import pytest
from uuid6 import uuid7
-from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import StartupDetails
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/ti.py b/tests/api_fastapi/execution_api/test_app.py
similarity index 59%
rename from task_sdk/src/airflow/sdk/api/datamodels/ti.py
rename to tests/api_fastapi/execution_api/test_app.py
index ce9e1e870ae..ccd8b4c8db9 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/ti.py
+++ b/tests/api_fastapi/execution_api/test_app.py
@@ -14,19 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
from __future__ import annotations
-import uuid
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
+
+
+def test_custom_openapi_includes_extra_schemas(client):
+ """Test to ensure that extra schemas are correctly included in the OpenAPI schema."""
-from pydantic import BaseModel
+ response = client.get("/execution/openapi.json")
+ assert response.status_code == 200
+ openapi_schema = response.json()
-class TaskInstance(BaseModel):
- id: uuid.UUID
+ assert "TaskInstance" in openapi_schema["components"]["schemas"]
+ schema = openapi_schema["components"]["schemas"]["TaskInstance"]
- task_id: str
- dag_id: str
- run_id: str
- try_number: int
- map_index: int | None = None
+ assert schema == TaskInstance.model_json_schema()