This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 872427c2534 Version the worker-bound TaskInstance fields in the
execution API schema (#68390)
872427c2534 is described below
commit 872427c25340331ceb9b053af4741ee906b147f1
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Jun 11 19:39:21 2026 +0100
Version the worker-bound TaskInstance fields in the execution API schema
(#68390)
The supervisor's task routing (added in #65958) reads queue, pool_slots and
priority_weight from the TaskInstance it receives. Those fields, however,
lived in a TaskInstanceDTO that was duplicated between airflow-core and
task-sdk and kept in sync by an AST-comparison prek hook, outside the
versioned execution API schema from which the Task SDK datamodels are
generated.
Add the three fields to the execution API TaskInstance schema (with a Cadwyn
version change) and regenerate the SDK datamodels, so the generated
TaskInstance carries everything the supervisor needs. (The follow-up change
below narrows this to queue alone.) Re-point StartupDetails and the
coordinator interfaces at the generated model, fold the core
TaskInstanceDTO into a subclass of the schema model that only adds the
executor-side fields, and drop the duplicated task-sdk DTO and the sync
hook.
The unused parent_context_carrier field (no readers or writers) is removed.
* Keep pool_slots and priority_weight executor-side; only queue joins the
schema
The supervisor only routes on queue; pool_slots and priority_weight are
executor concerns (queued-workload priority ordering and edge concurrency
slot accounting) that the worker never reads. Keep them on the executor
DTO, serialized as before, and add only queue to the worker-facing schema.
As a result, the generated TaskInstance no longer accepts
pool_slots/priority_weight kwargs, and its queue field is Optional in the
generated typing. The DTO also now inherits and validates the hostname
field, so the ECS adoption test's mock must set it explicitly.
---
.pre-commit-config.yaml | 6 -
airflow-core/docs/core-concepts/executor/index.rst | 6 +-
.../execution_api/datamodels/taskinstance.py | 4 +
.../api_fastapi/execution_api/versions/__init__.py | 2 +
.../execution_api/versions/v2026_06_30.py | 13 +-
.../src/airflow/executors/workloads/task.py | 32 +--
.../tests/unit/executors/test_workloads.py | 37 ++++
devel-common/src/tests_common/pytest_plugin.py | 18 +-
.../org/apache/airflow/sdk/execution/TaskTest.kt | 4 +-
.../amazon/aws/executors/ecs/test_ecs_executor.py | 1 +
.../edge3/worker_api/v2-edge-generated.yaml | 48 +++--
scripts/ci/prek/check_task_instance_dto_sync.py | 125 -----------
.../src/airflow/sdk/api/datamodels/_generated.py | 1 +
.../src/airflow/sdk/coordinators/_subprocess.py | 9 +-
.../sdk/coordinators/executable/coordinator.py | 4 +-
.../airflow/sdk/coordinators/java/coordinator.py | 4 +-
task-sdk/src/airflow/sdk/execution_time/comms.py | 7 +-
.../src/airflow/sdk/execution_time/coordinator.py | 6 +-
.../airflow/sdk/execution_time/schema/schema.json | 236 +++++++--------------
.../src/airflow/sdk/execution_time/supervisor.py | 9 +-
.../sdk/execution_time/workloads/__init__.py | 23 --
.../airflow/sdk/execution_time/workloads/task.py | 53 -----
.../coordinators/executable/test_coordinator.py | 8 +-
.../task_sdk/coordinators/java/test_coordinator.py | 8 +-
.../tests/task_sdk/coordinators/test_subprocess.py | 8 +-
.../task_sdk/execution_time/test_supervisor.py | 74 ++-----
.../task_sdk/execution_time/test_task_runner.py | 63 ++----
27 files changed, 240 insertions(+), 569 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 917a262d5ae..4d20d7784a0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -539,12 +539,6 @@ repos:
language: python
pass_filenames: false
files:
^dev/registry/registry_tools/types\.py$|^registry/src/_data/types\.json$
- - id: check-task-instance-dto-sync
- name: Check BaseTaskInstanceDTO duplicate is in sync between core and
task-sdk
- entry: ./scripts/ci/prek/check_task_instance_dto_sync.py
- language: python
- pass_filenames: false
- files:
^airflow-core/src/airflow/executors/workloads/task\.py$|^task-sdk/src/airflow/sdk/execution_time/workloads/task\.py$
- id: ruff
name: Run 'ruff' for extremely fast Python linting
description: "Run 'ruff' for extremely fast Python linting"
diff --git a/airflow-core/docs/core-concepts/executor/index.rst
b/airflow-core/docs/core-concepts/executor/index.rst
index 9420c55d84e..cf21daafd3f 100644
--- a/airflow-core/docs/core-concepts/executor/index.rst
+++ b/airflow-core/docs/core-concepts/executor/index.rst
@@ -271,20 +271,18 @@ Example:
ExecuteTask(
token="mock",
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"),
task_id="mock",
dag_id="mock",
run_id="mock",
try_number=1,
+ dag_version_id=UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"),
map_index=-1,
pool_slots=1,
queue="default",
priority_weight=1,
executor_config=None,
- parent_context_carrier=None,
- context_carrier=None,
- queued_dttm=None,
),
dag_rel_path=PurePosixPath("mock.py"),
bundle_info=BundleInfo(name="n/a", version="no matter"),
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index c7f6b1ee9f8..fe7f5a5ce05 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -286,6 +286,10 @@ class TaskInstance(BaseModel):
map_index: int = -1
hostname: str | None = None
context_carrier: dict | None = None
+ # The supervisor routes tasks to a coordinator by queue. The default keeps
+ # hand-built instances (tests, dry runs) valid; the executor workload
+ # always sends the real value.
+ queue: str = "default"
class AssetReferenceAssetEventDagRun(StrictBaseModel):
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index 9e4d486aa30..656bb8dce10 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -49,6 +49,7 @@ from airflow.api_fastapi.execution_api.versions.v2026_06_16
import (
from airflow.api_fastapi.execution_api.versions.v2026_06_30 import (
AddAwaitingInputStatePayload,
AddConnectionTestEndpoint,
+ AddTaskInstanceQueueField,
AddVariableKeysEndpoint,
)
@@ -59,6 +60,7 @@ bundle = VersionBundle(
AddVariableKeysEndpoint,
AddConnectionTestEndpoint,
AddAwaitingInputStatePayload,
+ AddTaskInstanceQueueField,
),
Version(
"2026-06-16",
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py
index f9c22f6cd82..cfa5b616396 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py
@@ -19,7 +19,10 @@ from __future__ import annotations
from cadwyn import VersionChange, endpoint, schema
-from airflow.api_fastapi.execution_api.datamodels.taskinstance import
TIAwaitingInputStatePayload
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+ TaskInstance,
+ TIAwaitingInputStatePayload,
+)
class AddVariableKeysEndpoint(VersionChange):
@@ -41,6 +44,14 @@ class AddConnectionTestEndpoint(VersionChange):
)
+class AddTaskInstanceQueueField(VersionChange):
+ """Add the `queue` field to the TaskInstance model."""
+
+ description = __doc__
+
+ instructions_to_migrate_to_previous_version =
(schema(TaskInstance).field("queue").didnt_exist,)
+
+
class AddAwaitingInputStatePayload(VersionChange):
"""Add the awaiting_input task instance state transition payload
(Human-in-the-loop, no trigger)."""
diff --git a/airflow-core/src/airflow/executors/workloads/task.py
b/airflow-core/src/airflow/executors/workloads/task.py
index 9af3f33c10e..b4bf02ea47b 100644
--- a/airflow-core/src/airflow/executors/workloads/task.py
+++ b/airflow-core/src/airflow/executors/workloads/task.py
@@ -18,12 +18,12 @@
from __future__ import annotations
-import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Literal
-from pydantic import BaseModel, Field
+from pydantic import Field
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import
TaskInstance
from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo
from airflow.utils.state import TaskInstanceState
@@ -33,36 +33,20 @@ if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
-class BaseTaskInstanceDTO(BaseModel):
+class TaskInstanceDTO(TaskInstance):
"""
- Base schema for TaskInstance with the minimal fields shared by Executors
and the Task SDK.
+ The versioned execution API ``TaskInstance`` schema plus executor-only
fields.
- This definition is duplicated in
:mod:`airflow.sdk.execution_time.workloads.task`
- and the two are kept in sync by the ``check-task-instance-dto-sync`` prek
- hook. Update both files together.
+ The base class is the single source of truth for the fields a worker needs;
+ the fields added here are executor concerns (queueing order and pool
+ accounting) the worker never reads.
"""
- id: uuid.UUID
- dag_version_id: uuid.UUID
- task_id: str
- dag_id: str
- run_id: str
- try_number: int
- map_index: int = -1
-
pool_slots: int
- queue: str
priority_weight: int
- executor_config: dict | None = Field(default=None, exclude=True)
-
- parent_context_carrier: dict | None = None
- context_carrier: dict | None = None
-
-
-class TaskInstanceDTO(BaseTaskInstanceDTO):
- """TaskInstanceDTO with executor-specific ``external_executor_id`` field
and ``key`` property."""
external_executor_id: str | None = Field(default=None, exclude=True)
+ executor_config: dict | None = Field(default=None, exclude=True)
# TODO: Task-SDK: Can we replace TaskInstanceKey with just the uuid across
the codebase?
@property
diff --git a/airflow-core/tests/unit/executors/test_workloads.py
b/airflow-core/tests/unit/executors/test_workloads.py
index 6f027a4d3be..a063b91140d 100644
--- a/airflow-core/tests/unit/executors/test_workloads.py
+++ b/airflow-core/tests/unit/executors/test_workloads.py
@@ -32,6 +32,7 @@ from airflow.executors.workloads.callback import CallbackDTO,
CallbackFetchMetho
from airflow.executors.workloads.task import ExecuteTask
from airflow.executors.workloads.types import state_class_for_key
from airflow.models.callback import CallbackKey
+from airflow.sdk.api.datamodels._generated import TaskInstance as
GeneratedTaskInstance
def test_task_instance_alias_keeps_backwards_compat():
@@ -134,3 +135,39 @@ def test_callback_dto_key_returns_callback_key_instance():
assert isinstance(key, CallbackKey)
assert key.id == cid
assert str(key) == cid
+
+
+def test_workload_ti_round_trips_through_sdk_generated_model():
+ """
+ The executor-side DTO and the SDK's generated TaskInstance share the
+ execution API schema; the serialized workload must carry the routing
+ fields and exclude the executor-only ones.
+ """
+ ti = TaskInstanceDTO(
+ id=uuid4(),
+ dag_version_id=uuid4(),
+ task_id="test_task",
+ dag_id="test_dag",
+ run_id="test_run",
+ try_number=2,
+ map_index=3,
+ pool_slots=4,
+ queue="jdk-17",
+ priority_weight=5,
+ external_executor_id="celery-id",
+ executor_config={"KubernetesExecutor": {"image": "custom"}},
+ )
+
+ dumped = ti.model_dump(mode="json")
+ assert "external_executor_id" not in dumped
+ assert "executor_config" not in dumped
+ # Executor-side scheduling fields stay on the workload wire (older workers
+ # deserialize the workload with a model that requires them) but are not
+ # part of the worker-facing schema.
+ assert dumped["pool_slots"] == 4
+ assert dumped["priority_weight"] == 5
+
+ received = GeneratedTaskInstance.model_validate(dumped)
+ assert received.queue == "jdk-17"
+ assert received.map_index == 3
+ assert not hasattr(received, "pool_slots")
diff --git a/devel-common/src/tests_common/pytest_plugin.py
b/devel-common/src/tests_common/pytest_plugin.py
index 0958dbd663a..eaac4e37fe0 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -2584,16 +2584,12 @@ def create_runtime_ti(mocked_parse):
should_retry: bool | None = None,
max_tries: int | None = None,
) -> RuntimeTaskInstance:
- from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
-
- if AIRFLOW_V_3_3_PLUS:
- from airflow.sdk.execution_time.workloads.task import
TaskInstanceDTO
- else:
- from airflow.sdk.api.datamodels._generated import ( # type:
ignore[no-redef,assignment]
- TaskInstance as TaskInstanceDTO,
- )
-
- from airflow.sdk.api.datamodels._generated import DagRun, DagRunState,
TIRunContext
+ from airflow.sdk.api.datamodels._generated import (
+ DagRun,
+ DagRunState,
+ TaskInstance as TaskInstanceDTO,
+ TIRunContext,
+ )
from airflow.utils.types import DagRunType
if isinstance(logical_date, str):
@@ -2678,9 +2674,7 @@ def create_runtime_ti(mocked_parse):
try_number=try_number,
map_index=map_index, # type: ignore[arg-type]
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="",
bundle_info=BundleInfo(name="anything", version="any"),
diff --git
a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TaskTest.kt
b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TaskTest.kt
index 1bf91220f93..0542516228d 100644
--- a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TaskTest.kt
+++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TaskTest.kt
@@ -29,7 +29,7 @@ import org.apache.airflow.sdk.execution.comm.DagRun
import org.apache.airflow.sdk.execution.comm.StartupDetails
import org.apache.airflow.sdk.execution.comm.SucceedTask
import org.apache.airflow.sdk.execution.comm.TIRunContext
-import org.apache.airflow.sdk.execution.comm.TaskInstanceDTO
+import org.apache.airflow.sdk.execution.comm.TaskInstance
import org.apache.airflow.sdk.execution.comm.TaskState
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.DisplayName
@@ -76,7 +76,7 @@ class TaskTest {
private fun startupDetails(taskId: String): StartupDetails =
StartupDetails().also {
it.ti =
- TaskInstanceDTO().also { o ->
+ TaskInstance().also { o ->
o.id = UUID.randomUUID()
o.taskId = taskId
o.dagId = "test_dag"
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
index f350c884981..ca6fab54fa2 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -1303,6 +1303,7 @@ class TestAwsEcsExecutor:
task.pool_slots = 1
task.priority_weight = 1
task.context_carrier = {}
+ task.hostname = None
task.queued_dttm = dt.datetime.now()
# Set up nested attributes for BundleInfo
task.dag_model = mock.Mock()
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
index 408cb140783..7a7fb194fb0 100644
---
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
+++
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
@@ -1277,10 +1277,6 @@ components:
type: string
format: uuid
title: Id
- dag_version_id:
- type: string
- format: uuid
- title: Dag Version Id
task_id:
type: string
title: Task Id
@@ -1293,45 +1289,55 @@ components:
try_number:
type: integer
title: Try Number
+ dag_version_id:
+ type: string
+ format: uuid
+ title: Dag Version Id
map_index:
type: integer
title: Map Index
default: -1
- pool_slots:
- type: integer
- title: Pool Slots
- queue:
- type: string
- title: Queue
- priority_weight:
- type: integer
- title: Priority Weight
- parent_context_carrier:
+ hostname:
anyOf:
- - additionalProperties: true
- type: object
+ - type: string
- type: 'null'
- title: Parent Context Carrier
+ title: Hostname
context_carrier:
anyOf:
- additionalProperties: true
type: object
- type: 'null'
title: Context Carrier
+ queue:
+ type: string
+ title: Queue
+ default: default
+ pool_slots:
+ type: integer
+ title: Pool Slots
+ priority_weight:
+ type: integer
+ title: Priority Weight
type: object
required:
- id
- - dag_version_id
- task_id
- dag_id
- run_id
- try_number
+ - dag_version_id
- pool_slots
- - queue
- priority_weight
title: TaskInstanceDTO
- description: TaskInstanceDTO with executor-specific
``external_executor_id``
- field and ``key`` property.
+ description: 'The versioned execution API ``TaskInstance`` schema plus
executor-only
+ fields.
+
+
+ The base class is the single source of truth for the fields a worker
needs;
+
+ the fields added here are executor concerns (queueing order and pool
+
+ accounting) the worker never reads.'
TaskInstanceState:
type: string
enum:
diff --git a/scripts/ci/prek/check_task_instance_dto_sync.py
b/scripts/ci/prek/check_task_instance_dto_sync.py
deleted file mode 100755
index 689d35a4d15..00000000000
--- a/scripts/ci/prek/check_task_instance_dto_sync.py
+++ /dev/null
@@ -1,125 +0,0 @@
-#!/usr/bin/env python
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Verify that the duplicate ``BaseTaskInstanceDTO`` definitions in airflow-core
-and task-sdk stay structurally identical.
-
-``BaseTaskInstanceDTO`` is duplicated (not shared) in:
-
-- ``airflow-core/src/airflow/executors/workloads/task.py``
-- ``task-sdk/src/airflow/sdk/execution_time/workloads/task.py``
-
-This hook compares the *fields* (annotated assignments) and bases of both
-``BaseTaskInstanceDTO`` classes. The concrete ``TaskInstanceDTO`` subclasses
-in each file are allowed to differ (airflow-core adds an executor-specific
-``key`` property that depends on ``airflow.models``, which the Task SDK
-does not have access to).
-"""
-
-from __future__ import annotations
-
-import ast
-import sys
-from pathlib import Path
-
-AIRFLOW_ROOT = Path(__file__).parents[3].resolve()
-CORE_FILE = AIRFLOW_ROOT / "airflow-core" / "src" / "airflow" / "executors" /
"workloads" / "task.py"
-SDK_FILE = AIRFLOW_ROOT / "task-sdk" / "src" / "airflow" / "sdk" /
"execution_time" / "workloads" / "task.py"
-CLASS_NAME = "BaseTaskInstanceDTO"
-
-
-def _find_class(tree: ast.AST, class_name: str) -> ast.ClassDef | None:
- for node in ast.walk(tree):
- if isinstance(node, ast.ClassDef) and node.name == class_name:
- return node
- return None
-
-
-def _field_signature(class_node: ast.ClassDef) -> list[tuple[str, str, str |
None]]:
- """Return a normalized list of ``(name, annotation, default)`` for each
field."""
- fields: list[tuple[str, str, str | None]] = []
- for stmt in class_node.body:
- if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target,
ast.Name):
- name = stmt.target.id
- annotation = ast.unparse(stmt.annotation)
- default = ast.unparse(stmt.value) if stmt.value is not None else
None
- fields.append((name, annotation, default))
- return fields
-
-
-def _bases(class_node: ast.ClassDef) -> list[str]:
- return [ast.unparse(base) for base in class_node.bases]
-
-
-def _extract(file_path: Path) -> tuple[list[str], list[tuple[str, str, str |
None]]]:
- source = file_path.read_text()
- tree = ast.parse(source, filename=str(file_path))
- class_node = _find_class(tree, CLASS_NAME)
- if class_node is None:
- print(f"ERROR: Could not find class {CLASS_NAME} in {file_path}",
file=sys.stderr)
- sys.exit(1)
- return _bases(class_node), _field_signature(class_node)
-
-
-def main() -> None:
- core_bases, core_fields = _extract(CORE_FILE)
- sdk_bases, sdk_fields = _extract(SDK_FILE)
-
- if core_bases == sdk_bases and core_fields == sdk_fields:
- sys.exit(0)
-
- print(
- f"\nERROR: {CLASS_NAME} definitions in airflow-core and task-sdk are
out of sync!",
- file=sys.stderr,
- )
- print(f"\n airflow-core: {CORE_FILE.relative_to(AIRFLOW_ROOT)}",
file=sys.stderr)
- print(f" task-sdk: {SDK_FILE.relative_to(AIRFLOW_ROOT)}",
file=sys.stderr)
-
- if core_bases != sdk_bases:
- print("\nClass bases differ:", file=sys.stderr)
- print(f" airflow-core: {core_bases}", file=sys.stderr)
- print(f" task-sdk: {sdk_bases}", file=sys.stderr)
-
- if core_fields != sdk_fields:
- core_set = {f[0]: f for f in core_fields}
- sdk_set = {f[0]: f for f in sdk_fields}
- only_in_core = sorted(set(core_set) - set(sdk_set))
- only_in_sdk = sorted(set(sdk_set) - set(core_set))
- differing = sorted(name for name in set(core_set) & set(sdk_set) if
core_set[name] != sdk_set[name])
- if only_in_core:
- print(f"\n Fields only in airflow-core: {only_in_core}",
file=sys.stderr)
- if only_in_sdk:
- print(f"\n Fields only in task-sdk: {only_in_sdk}",
file=sys.stderr)
- for name in differing:
- print(
- f"\n Field {name!r} differs:"
- f"\n airflow-core: {core_set[name]}"
- f"\n task-sdk: {sdk_set[name]}",
- file=sys.stderr,
- )
-
- print(
- f"\nUpdate both files together so the two {CLASS_NAME} definitions
stay in sync.",
- file=sys.stderr,
- )
- sys.exit(1)
-
-
-if __name__ == "__main__":
- main()
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 9bdd8ac4fba..99aedae25b0 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -560,6 +560,7 @@ class TaskInstance(BaseModel):
map_index: Annotated[int | None, Field(title="Map Index")] = -1
hostname: Annotated[str | None, Field(title="Hostname")] = None
context_carrier: Annotated[dict[str, Any] | None, Field(title="Context
Carrier")] = None
+ queue: Annotated[str | None, Field(title="Queue")] = "default"
class BundleInfo(BaseModel):
diff --git a/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
b/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
index 49bea634ac5..bac9f828b78 100644
--- a/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
+++ b/task-sdk/src/airflow/sdk/coordinators/_subprocess.py
@@ -51,8 +51,7 @@ if TYPE_CHECKING:
from typing_extensions import Self
from airflow.sdk.api.client import Client
- from airflow.sdk.api.datamodels._generated import BundleInfo
- from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+ from airflow.sdk.api.datamodels._generated import BundleInfo, TaskInstance
Tracked = TypeVar("Tracked", socket.socket, subprocess.Popen)
@@ -279,7 +278,7 @@ class _PopenActivitySubprocess(ActivitySubprocess):
def start( # type: ignore[override]
cls,
*,
- what: TaskInstanceDTO,
+ what: TaskInstance,
dag_rel_path: str | os.PathLike[str],
bundle_info,
logger: FilteringBoundLogger | None = None,
@@ -369,7 +368,7 @@ class SubprocessCoordinator(BaseCoordinator):
task_startup_timeout: float = 10.0
- def _build_execute_task_command(self, *, what: TaskInstanceDTO) ->
tuple[list[str], str | None]:
+ def _build_execute_task_command(self, *, what: TaskInstance) ->
tuple[list[str], str | None]:
"""
Build the subprocess command and resolve its supervisor wire-schema
version for *what*.
@@ -385,7 +384,7 @@ class SubprocessCoordinator(BaseCoordinator):
def execute_task(
self,
*,
- what: TaskInstanceDTO,
+ what: TaskInstance,
dag_rel_path: str | os.PathLike[str],
bundle_info: BundleInfo,
client: Client,
diff --git a/task-sdk/src/airflow/sdk/coordinators/executable/coordinator.py
b/task-sdk/src/airflow/sdk/coordinators/executable/coordinator.py
index 7b5290be021..ed11f97165d 100644
--- a/task-sdk/src/airflow/sdk/coordinators/executable/coordinator.py
+++ b/task-sdk/src/airflow/sdk/coordinators/executable/coordinator.py
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger
from typing_extensions import Self
- from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+ from airflow.sdk.api.datamodels._generated import TaskInstance
log: FilteringBoundLogger =
structlog.get_logger(logger_name="coordinators.executable")
@@ -390,6 +390,6 @@ class ExecutableCoordinator(SubprocessCoordinator):
validator=attrs.validators.min_len(1),
)
- def _build_execute_task_command(self, *, what: TaskInstanceDTO) ->
tuple[list[str], str | None]:
+ def _build_execute_task_command(self, *, what: TaskInstance) ->
tuple[list[str], str | None]:
bundle = _Bundle.find(self.executables_root, what.dag_id)
return [str(bundle.path)], bundle.schema_version
diff --git a/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
b/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
index d6aebe707f5..cb74c64dfd4 100644
--- a/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
+++ b/task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger
from typing_extensions import Self
- from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+ from airflow.sdk.api.datamodels._generated import TaskInstance
log: FilteringBoundLogger =
structlog.get_logger(logger_name="coordinators.java")
@@ -219,7 +219,7 @@ class JavaCoordinator(SubprocessCoordinator):
)
main_class: str = ""
- def _build_execute_task_command(self, *, what: TaskInstanceDTO) ->
tuple[list[str], str | None]:
+ def _build_execute_task_command(self, *, what: TaskInstance) ->
tuple[list[str], str | None]:
jar = _JarInfo.find(self.jars_root, self.main_class)
command = [
self.java_executable,
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 330fe934c7c..c2c4fd69516 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -80,6 +80,7 @@ from airflow.sdk.api.datamodels._generated import (
PreviousTIResponse,
PrevSuccessfulDagRunResponse,
TaskBreadcrumbsResponse,
+ TaskInstance,
TaskInstanceState,
TaskStatesResponse,
TaskStoreResponse,
@@ -98,10 +99,6 @@ from airflow.sdk.api.datamodels._generated import (
XComSequenceSliceResponse,
)
from airflow.sdk.exceptions import ErrorType
-from airflow.sdk.execution_time.workloads.task import (
- # Pydantic needs this at runtime since we don't model_rebuild()
StartupDetails.
- TaskInstanceDTO, # noqa: TC001
-)
try:
from socket import recv_fds
@@ -338,7 +335,7 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
class StartupDetails(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
- ti: TaskInstanceDTO
+ ti: TaskInstance
dag_rel_path: str
bundle_info: BundleInfo
start_date: datetime
diff --git a/task-sdk/src/airflow/sdk/execution_time/coordinator.py
b/task-sdk/src/airflow/sdk/execution_time/coordinator.py
index 072e9c0b6c7..e233ab8e4ae 100644
--- a/task-sdk/src/airflow/sdk/execution_time/coordinator.py
+++ b/task-sdk/src/airflow/sdk/execution_time/coordinator.py
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
from typing_extensions import Self
from airflow.sdk.api.client import Client
- from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+ from airflow.sdk.api.datamodels._generated import TaskInstance
__all__ = [
"BaseCoordinator",
@@ -89,7 +89,7 @@ class BaseCoordinator:
def execute_task(
self,
*,
- what: TaskInstanceDTO,
+ what: TaskInstance,
dag_rel_path: str | PathLike[str],
bundle_info,
client: Client,
@@ -122,7 +122,7 @@ class _PythonCoordinator(BaseCoordinator):
def execute_task(
self,
*,
- what: TaskInstanceDTO,
+ what: TaskInstance,
dag_rel_path: str | PathLike[str],
bundle_info,
client: Client,
diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
index c02aa76791e..0d1a3dc61c7 100644
--- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
+++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
@@ -3656,7 +3656,7 @@
"StartupDetails": {
"properties": {
"ti": {
- "$ref": "#/$defs/TaskInstanceDTO"
+ "$ref": "#/$defs/TaskInstance"
},
"dag_rel_path": {
"title": "Dag Rel Path",
@@ -3891,164 +3891,6 @@
"title": "TaskCallbackRequest",
"type": "object"
},
- "TaskInstance": {
- "description": "Schema for TaskInstance model with minimal required
fields needed for Runtime.",
- "properties": {
- "id": {
- "format": "uuid",
- "title": "Id",
- "type": "string"
- },
- "task_id": {
- "title": "Task Id",
- "type": "string"
- },
- "dag_id": {
- "title": "Dag Id",
- "type": "string"
- },
- "run_id": {
- "title": "Run Id",
- "type": "string"
- },
- "try_number": {
- "title": "Try Number",
- "type": "integer"
- },
- "dag_version_id": {
- "format": "uuid",
- "title": "Dag Version Id",
- "type": "string"
- },
- "map_index": {
- "default": -1,
- "title": "Map Index",
- "type": "integer"
- },
- "hostname": {
- "anyOf": [
- {
- "type": "string"
- },
- {
- "type": "null"
- }
- ],
- "default": null,
- "title": "Hostname"
- },
- "context_carrier": {
- "anyOf": [
- {
- "additionalProperties": true,
- "type": "object"
- },
- {
- "type": "null"
- }
- ],
- "default": null,
- "title": "Context Carrier"
- }
- },
- "required": [
- "id",
- "task_id",
- "dag_id",
- "run_id",
- "try_number",
- "dag_version_id"
- ],
- "title": "TaskInstance",
- "type": "object"
- },
- "TaskInstanceDTO": {
- "description": "Task SDK TaskInstanceDTO.",
- "properties": {
- "id": {
- "format": "uuid",
- "title": "Id",
- "type": "string"
- },
- "dag_version_id": {
- "format": "uuid",
- "title": "Dag Version Id",
- "type": "string"
- },
- "task_id": {
- "title": "Task Id",
- "type": "string"
- },
- "dag_id": {
- "title": "Dag Id",
- "type": "string"
- },
- "run_id": {
- "title": "Run Id",
- "type": "string"
- },
- "try_number": {
- "title": "Try Number",
- "type": "integer"
- },
- "map_index": {
- "default": -1,
- "title": "Map Index",
- "type": "integer"
- },
- "pool_slots": {
- "title": "Pool Slots",
- "type": "integer"
- },
- "queue": {
- "title": "Queue",
- "type": "string"
- },
- "priority_weight": {
- "title": "Priority Weight",
- "type": "integer"
- },
- "parent_context_carrier": {
- "anyOf": [
- {
- "additionalProperties": true,
- "type": "object"
- },
- {
- "type": "null"
- }
- ],
- "default": null,
- "title": "Parent Context Carrier"
- },
- "context_carrier": {
- "anyOf": [
- {
- "additionalProperties": true,
- "type": "object"
- },
- {
- "type": "null"
- }
- ],
- "default": null,
- "title": "Context Carrier"
- }
- },
- "required": [
- "id",
- "dag_version_id",
- "task_id",
- "dag_id",
- "run_id",
- "try_number",
- "pool_slots",
- "queue",
- "priority_weight"
- ],
- "title": "TaskInstanceDTO",
- "type": "object"
- },
"TaskInstanceState": {
"description": "All possible states that a Task Instance can be
in.\n\nNote that None is also allowed, so always use this in a type hint with
Optional.",
"enum": [
@@ -4930,6 +4772,82 @@
"title": "TIRunContext",
"type": "object"
},
+ "TaskInstance": {
+ "description": "Schema for TaskInstance model with minimal required
fields needed for Runtime.",
+ "properties": {
+ "id": {
+ "format": "uuid",
+ "title": "Id",
+ "type": "string"
+ },
+ "task_id": {
+ "title": "Task Id",
+ "type": "string"
+ },
+ "dag_id": {
+ "title": "Dag Id",
+ "type": "string"
+ },
+ "run_id": {
+ "title": "Run Id",
+ "type": "string"
+ },
+ "try_number": {
+ "title": "Try Number",
+ "type": "integer"
+ },
+ "dag_version_id": {
+ "format": "uuid",
+ "title": "Dag Version Id",
+ "type": "string"
+ },
+ "map_index": {
+ "default": -1,
+ "title": "Map Index",
+ "type": "integer"
+ },
+ "hostname": {
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "default": null,
+ "title": "Hostname"
+ },
+ "context_carrier": {
+ "anyOf": [
+ {
+ "additionalProperties": true,
+ "type": "object"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "default": null,
+ "title": "Context Carrier"
+ },
+ "queue": {
+ "default": "default",
+ "title": "Queue",
+ "type": "string"
+ }
+ },
+ "required": [
+ "id",
+ "task_id",
+ "dag_id",
+ "run_id",
+ "try_number",
+ "dag_version_id"
+ ],
+ "title": "TaskInstance",
+ "type": "object"
+ },
"VariableResponse": {
"additionalProperties": false,
"description": "Variable schema for responses with fields that are
needed for Runtime.",
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 6527e651041..4ea82c45327 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -172,7 +172,6 @@ if TYPE_CHECKING:
from airflow.executors.workloads import BundleInfo
from airflow.sdk.bases.secrets_backend import BaseSecretsBackend
from airflow.sdk.definitions.connection import Connection
- from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise",
"supervise_task"]
@@ -1317,7 +1316,7 @@ class ActivitySubprocess(WatchedSubprocess):
def start( # type: ignore[override]
cls,
*,
- what: TaskInstanceDTO,
+ what: TaskInstance,
dag_rel_path: str | os.PathLike[str],
bundle_info,
client: Client,
@@ -1346,7 +1345,7 @@ class ActivitySubprocess(WatchedSubprocess):
def _on_child_started(
self,
*,
- ti: TaskInstanceDTO,
+ ti: TaskInstance,
dag_rel_path: str | os.PathLike[str],
bundle_info,
sentry_integration: str,
@@ -2412,7 +2411,7 @@ def _configure_logging(log_path: str, client: Client) ->
tuple[FilteringBoundLog
def supervise_task(
*,
- ti: TaskInstanceDTO,
+ ti: TaskInstance,
bundle_info: BundleInfo,
dag_rel_path: str | os.PathLike[str],
token: str,
@@ -2482,7 +2481,7 @@ def supervise_task(
raise ValueError("dag_path is required")
try:
- coordinator = get_coordinator_manager().for_queue(ti.queue)
+ coordinator = get_coordinator_manager().for_queue(ti.queue or
"default")
except:
log.exception(
"Failed to initialize coordinator for task",
diff --git a/task-sdk/src/airflow/sdk/execution_time/workloads/__init__.py
b/task-sdk/src/airflow/sdk/execution_time/workloads/__init__.py
deleted file mode 100644
index cdf955e742d..00000000000
--- a/task-sdk/src/airflow/sdk/execution_time/workloads/__init__.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Workload schemas for Task SDK execution-time communication."""
-
-from __future__ import annotations
-
-from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
-
-__all__ = ["TaskInstanceDTO"]
diff --git a/task-sdk/src/airflow/sdk/execution_time/workloads/task.py
b/task-sdk/src/airflow/sdk/execution_time/workloads/task.py
deleted file mode 100644
index ceff200856f..00000000000
--- a/task-sdk/src/airflow/sdk/execution_time/workloads/task.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Task workload schemas for Task SDK execution-time communication."""
-
-from __future__ import annotations
-
-import uuid
-
-from pydantic import BaseModel, Field
-
-
-class BaseTaskInstanceDTO(BaseModel):
- """
- Base schema for TaskInstance with the minimal fields shared by Executors
and the Task SDK.
-
- This class is duplicated in :mod:`airflow.executors.workloads.task` and the
- two definitions are kept in sync by the ``check-task-instance-dto-sync``
- prek hook. Update both files together.
- """
-
- id: uuid.UUID
- dag_version_id: uuid.UUID
- task_id: str
- dag_id: str
- run_id: str
- try_number: int
- map_index: int = -1
-
- pool_slots: int
- queue: str
- priority_weight: int
- executor_config: dict | None = Field(default=None, exclude=True)
-
- parent_context_carrier: dict | None = None
- context_carrier: dict | None = None
-
-
-class TaskInstanceDTO(BaseTaskInstanceDTO):
- """Task SDK TaskInstanceDTO."""
diff --git
a/task-sdk/tests/task_sdk/coordinators/executable/test_coordinator.py
b/task-sdk/tests/task_sdk/coordinators/executable/test_coordinator.py
index 23b45e4f144..ec426a09ed7 100644
--- a/task-sdk/tests/task_sdk/coordinators/executable/test_coordinator.py
+++ b/task-sdk/tests/task_sdk/coordinators/executable/test_coordinator.py
@@ -31,6 +31,7 @@ import pytest
import yaml
from uuid6 import uuid7
+from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.coordinators.executable.coordinator import (
FOOTER_MAGIC,
FOOTER_SIZE,
@@ -41,7 +42,6 @@ from airflow.sdk.coordinators.executable.coordinator import (
)
from airflow.sdk.execution_time.coordinator import BaseCoordinator
from airflow.sdk.execution_time.supervisor import ActivitySubprocess
-from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
@@ -111,8 +111,8 @@ def _make_executable(path: Path) -> Path:
return path
-def _make_ti(dag_id: str = "tutorial_dag", queue: str = "executable") ->
TaskInstanceDTO:
- return TaskInstanceDTO(
+def _make_ti(dag_id: str = "tutorial_dag", queue: str = "executable") ->
TaskInstance:
+ return TaskInstance(
id=uuid7(),
dag_version_id=uuid7(),
task_id="task_1",
@@ -120,9 +120,7 @@ def _make_ti(dag_id: str = "tutorial_dag", queue: str =
"executable") -> TaskIns
run_id="run_1",
try_number=1,
map_index=-1,
- pool_slots=1,
queue=queue,
- priority_weight=1,
)
diff --git a/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
b/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
index 8670a0e895c..c30ecff35b9 100644
--- a/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
+++ b/task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
@@ -29,6 +29,7 @@ from unittest.mock import MagicMock, patch
import pytest
from uuid6 import uuid7
+from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.coordinators.java.coordinator import (
JavaCoordinator,
_calculate_classpath,
@@ -37,7 +38,6 @@ from airflow.sdk.coordinators.java.coordinator import (
)
from airflow.sdk.execution_time.coordinator import BaseCoordinator
from airflow.sdk.execution_time.supervisor import ActivitySubprocess
-from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
@@ -45,8 +45,8 @@ if not AIRFLOW_V_3_3_PLUS:
pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0",
allow_module_level=True)
-def _make_ti(dag_id: str = "test_dag", queue: str = "java") -> TaskInstanceDTO:
- return TaskInstanceDTO(
+def _make_ti(dag_id: str = "test_dag", queue: str = "java") -> TaskInstance:
+ return TaskInstance(
id=uuid7(),
dag_version_id=uuid7(),
task_id="task_1",
@@ -54,9 +54,7 @@ def _make_ti(dag_id: str = "test_dag", queue: str = "java")
-> TaskInstanceDTO:
run_id="run_1",
try_number=1,
map_index=-1,
- pool_slots=1,
queue=queue,
- priority_weight=1,
)
diff --git a/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
b/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
index 2c043526a4e..8bdd905c6b0 100644
--- a/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
+++ b/task-sdk/tests/task_sdk/coordinators/test_subprocess.py
@@ -32,6 +32,7 @@ import pytest
from uuid6 import uuid7
from airflow.sdk.api.client import Client, TaskInstanceOperations
+from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.coordinators._subprocess import (
SubprocessCoordinator,
_accept_connections,
@@ -43,7 +44,6 @@ from airflow.sdk.coordinators._subprocess import (
)
from airflow.sdk.execution_time.coordinator import BaseCoordinator
from airflow.sdk.execution_time.supervisor import ActivitySubprocess
-from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
@@ -51,8 +51,8 @@ if not AIRFLOW_V_3_3_PLUS:
pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0",
allow_module_level=True)
-def _make_ti(dag_id: str = "tutorial_dag", queue: str = "socket") ->
TaskInstanceDTO:
- return TaskInstanceDTO(
+def _make_ti(dag_id: str = "tutorial_dag", queue: str = "socket") ->
TaskInstance:
+ return TaskInstance(
id=uuid7(),
dag_version_id=uuid7(),
task_id="task_1",
@@ -60,9 +60,7 @@ def _make_ti(dag_id: str = "tutorial_dag", queue: str =
"socket") -> TaskInstanc
run_id="run_1",
try_number=1,
map_index=-1,
- pool_slots=1,
queue=queue,
- priority_weight=1,
)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 62a9d00fdf2..8ae91b8c15e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -65,6 +65,7 @@ from airflow.sdk.api.datamodels._generated import (
DagRunState,
DagRunType,
PreviousTIResponse,
+ TaskInstance,
TaskInstanceState,
)
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType,
TaskAlreadyRunningError
@@ -172,7 +173,6 @@ from airflow.sdk.execution_time.supervisor import (
supervise_task,
)
from airflow.sdk.execution_time.task_runner import run
-from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
from tests_common.test_utils.config import conf_vars
@@ -235,16 +235,14 @@ class TestSupervisor:
"""
Test that the supervisor validates server URL and dry_run parameter
combinations correctly.
"""
- ti = TaskInstanceDTO(
+ ti = TaskInstance(
id=uuid7(),
task_id="async",
dag_id="super_basic_deferred_run",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
)
bundle_info = BundleInfo(name="my-bundle", version=None)
@@ -271,16 +269,14 @@ class TestSupervisor:
client_with_ti_start,
):
"""SIGTERM to the supervisor process is forwarded to the task
subprocess."""
- ti = TaskInstanceDTO(
+ ti = TaskInstance(
id=uuid7(),
task_id="signal_task",
dag_id="signal_forward_test",
run_id="r",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
)
bundle_info = BundleInfo(name="my-bundle", version=None)
@@ -379,16 +375,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -457,16 +451,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -557,16 +549,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id=ti_id,
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
@@ -589,16 +579,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -627,16 +615,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id=uuid7(),
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=mock_client,
target=subprocess_main,
@@ -674,16 +660,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id=uuid7(),
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=mock_client,
target=lambda: None,
@@ -720,16 +704,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id=ti_id,
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=sdk_client.Client(base_url="", dry_run=True, token=""),
target=subprocess_main,
@@ -765,16 +747,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id=ti_id,
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=sdk_client.Client(base_url="", dry_run=True, token=""),
target=subprocess_main,
@@ -789,16 +769,14 @@ class TestWatchedSubprocess:
time_machine.move_to(instant, tick=False)
dagfile_path = test_dags_dir
- ti = TaskInstanceDTO(
+ ti = TaskInstance(
id=uuid7(),
task_id="hello",
dag_id="super_basic_run",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
)
bundle_info = BundleInfo(name="my-bundle", version=None)
@@ -833,16 +811,14 @@ class TestWatchedSubprocess:
"""
instant = timezone.datetime(2024, 11, 7, 12, 34, 56, 0)
- ti = TaskInstanceDTO(
+ ti = TaskInstance(
id=uuid7(),
task_id="async",
dag_id="super_basic_deferred_run",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
)
# Create a mock client to assert calls to the client
@@ -963,16 +939,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id=ti_id,
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
@@ -1049,16 +1023,14 @@ class TestWatchedSubprocess:
ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id=ti_id,
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
@@ -1262,16 +1234,14 @@ class TestWatchedSubprocess:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -1425,16 +1395,14 @@ class TestWatchedSubprocessKill:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id=ti_id,
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -3454,16 +3422,14 @@ class TestInProcessTestSupervisor:
task.dag = DAG(dag_id="test_dag")
# Create a simple TaskInstance datamodel to pass to the supervisor
- ti = TaskInstanceDTO(
+ ti = TaskInstance(
id=uuid7(),
dag_version_id=uuid7(),
dag_id="test_dag",
task_id=task.task_id,
run_id="r",
try_number=1,
- pool_slots=1,
queue="default",
- priority_weight=1,
)
# Patch the API client used by InProcessTestSupervisor to return a
predictable TI context
@@ -3880,16 +3846,14 @@ def test_reinit_supervisor_comms(monkeypatch,
client_with_ti_start, caplog):
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstanceDTO(
+ what=TaskInstance(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 471541a6429..48b51390469 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -176,7 +176,6 @@ from airflow.sdk.execution_time.task_runner import (
run,
startup,
)
-from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
from airflow.sdk.execution_time.xcom import XCom
from airflow.sdk.serde import deserialize
from airflow.triggers.base import BaseEventTrigger, BaseTrigger, TriggerEvent
@@ -208,16 +207,14 @@ class CustomOperator(BaseOperator):
def test_parse(test_dags_dir: Path, make_ti_context):
"""Test that checks parsing of a basic dag with an un-mocked parse."""
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="a",
dag_id="super_basic",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="super_basic.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -263,16 +260,14 @@ def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path,
make_ti_context):
mock_dag.tasks = [mock_task]
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="a",
dag_id="super_basic",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="super_basic.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -326,16 +321,14 @@ def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path,
make_ti_context):
def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id,
task_id, expected_error):
"""Check for nice error messages on dag not found."""
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id=task_id,
dag_id=dag_id,
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="super_basic.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -375,16 +368,14 @@ def
test_parse_not_found_does_not_reschedule_when_max_attempts_reached(test_dags
and should surface as a hard failure (SystemExit in the task runner
process).
"""
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="a",
dag_id="madeup_dag_id",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="super_basic.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -439,16 +430,14 @@ def
test_main_sends_reschedule_task_when_startup_reschedules(
mock_comms_instance.socket = None
mock_comms_decoder_cls.__getitem__.return_value.return_value =
mock_comms_instance
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="my_task",
dag_id="test_dag",
run_id="test_run",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
context_carrier={},
),
dag_rel_path="",
@@ -615,16 +604,14 @@ def
test_task_span_is_child_of_dag_run_span(make_ti_context):
# Step 3: build StartupDetails with ti.context_carrier = ti_carrier.
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="my_task",
dag_id="test_dag",
run_id="test_run",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
context_carrier=ti_carrier,
),
dag_rel_path="",
@@ -686,16 +673,14 @@ def
test_task_span_no_parent_when_no_context_carrier(make_ti_context):
provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter))
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="standalone_task",
dag_id="test_dag",
run_id="test_run",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
context_carrier=None,
),
dag_rel_path="",
@@ -730,16 +715,14 @@ def test_parse_module_in_bundle_root(tmp_path: Path,
make_ti_context):
dag1_path.write_text(textwrap.dedent(dag1_code))
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="a",
dag_id="dag_name",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="path_test.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -1180,16 +1163,14 @@ def test_basic_templated_dag(mocked_parse,
make_ti_context, mock_supervisor_comm
)
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="templated_task",
dag_id="basic_templated_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
bundle_info=FAKE_BUNDLE,
dag_rel_path="",
@@ -1299,16 +1280,14 @@ def test_startup_and_run_dag_with_rtif(
instant = timezone.datetime(2024, 12, 3, 10, 0)
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="templated_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,
@@ -1350,16 +1329,14 @@ def test_task_run_with_user_impersonation(
instant = timezone.datetime(2024, 12, 3, 10, 0)
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="impersonation_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,
@@ -1401,16 +1378,14 @@ def test_task_run_with_user_impersonation_default_user(
instant = timezone.datetime(2024, 12, 3, 10, 0)
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="impersonation_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,
@@ -1444,16 +1419,14 @@ def
test_task_run_with_user_impersonation_remove_krb5ccname_on_reexecuted_proces
instant = timezone.datetime(2024, 12, 3, 10, 0)
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="impersonation_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,
@@ -1620,16 +1593,14 @@ def test_dag_parsing_context(make_ti_context,
mock_supervisor_comms, monkeypatch
task_id = "conditional_task"
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id=task_id,
dag_id=dag_id,
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="dag_parsing_context.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -2269,7 +2240,7 @@ class TestRuntimeTaskInstance:
task = CustomOperator(task_id=test_task_id)
# In case of the specific map_index we should check it is passed to TI.
- # ``None`` is not a valid TaskInstanceDTO.map_index value, but
xcom_pull's
+ # ``None`` is not a valid TaskInstance.map_index value, but xcom_pull's
# behaviour with ``map_indexes=None`` is independent of the TI's own
map_index.
extra_for_ti = {"map_index": map_indexes} if isinstance(map_indexes,
int) else {}
runtime_ti = create_runtime_ti(task=task, **extra_for_ti)
@@ -4166,16 +4137,14 @@ class TestTaskRunnerCallsListeners:
task_id="test_task_runner_calls_listeners", do_xcom_push=True,
multiple_outputs=True
)
what = StartupDetails(
- ti=TaskInstanceDTO(
+ ti=TaskInstance(
id=uuid7(),
task_id="templated_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
- pool_slots=1,
queue="default",
- priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,