This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8607e9465b6 Fix edge executor to support handling execute callback workload (#67679)
8607e9465b6 is described below
commit 8607e9465b64e213855f3f9f4a9351e06914f975
Author: Jeongwoo Do <[email protected]>
AuthorDate: Tue Jun 2 02:27:57 2026 +0900
Fix edge executor to support handling execute callback workload (#67679)
* Fix edge executor to support handling execute callback workload
* fix logic
* fix logic
---
providers/edge3/docs/edge_executor.rst | 33 +++++
.../src/airflow/providers/edge3/cli/worker.py | 20 +--
.../providers/edge3/executors/edge_executor.py | 104 +++++++++------
.../src/airflow/providers/edge3/models/types.py | 46 +++++++
.../providers/edge3/worker_api/datamodels.py | 4 +-
.../providers/edge3/worker_api/routes/jobs.py | 17 ++-
.../edge3/worker_api/v2-edge-generated.yaml | 67 +++++++++-
.../edge3/tests/unit/edge3/cli/test_worker.py | 47 +++++++
.../unit/edge3/executors/test_edge_executor.py | 128 ++++++++++++++++++-
.../edge3/tests/unit/edge3/models/test_types.py | 140 +++++++++++++++++++++
.../unit/edge3/worker_api/routes/test_jobs.py | 87 ++++++++++++-
11 files changed, 640 insertions(+), 53 deletions(-)
diff --git a/providers/edge3/docs/edge_executor.rst
b/providers/edge3/docs/edge_executor.rst
index dd5041375c1..4452b869589 100644
--- a/providers/edge3/docs/edge_executor.rst
+++ b/providers/edge3/docs/edge_executor.rst
@@ -169,6 +169,39 @@ Here is an example setting pool_slots for a task:
task_with_template()
+
+.. _edge_executor:execute_callback:
+
+Support ExecuteCallback in Worker
+---------------------------------
+
+In addition to executing tasks, the EdgeExecutor can dispatch executor-level
+callbacks (``ExecuteCallback`` workloads, e.g. deadline callbacks) to edge
+workers. When the scheduler hands an ``ExecuteCallback`` to
+``EdgeExecutor.queue_workload``, it is enqueued into the same job table
+(``EdgeJobModel``) that is used for task workloads, so an edge worker picks it
+up alongside regular tasks without any additional configuration. Callback jobs
+are always placed on the ``[operators] default_queue`` and occupy one
+concurrency slot, so at least one edge worker must serve that queue.
+
+Callback jobs share the ``EdgeJobModel`` table with task jobs. They are
+distinguished by reserved values in the identifier columns:
+
+- ``dag_id`` is set to the constant tag ``ExecuteCallback``.
+- ``task_id`` is set to the callback ID.
+- ``run_id`` is set to ``ExecuteCallback-<callback_id>``.
+- ``map_index`` is fixed to ``-1`` and ``try_number`` to ``0``.
+
+When a worker fetches such a job through the worker API, the API server uses
+these reserved identifiers to deserialize the stored command payload as an
+``ExecuteCallback`` workload instead of an ``ExecuteTask``. On Airflow 3.3+,
+the worker executes both task and callback workloads via
+``BaseExecutor.run_workload`` (or the
+``airflow.sdk.execution_time.execute_workload`` entrypoint when using the
+subprocess path).
+
+.. note::
+
+   This feature is only active on Airflow 3.3 or newer. On earlier Airflow
+   versions the EdgeExecutor only handles ``ExecuteTask`` workloads and
+   rejects any other workload type with a ``TypeError``.
+
Current Limitations Edge Executor
---------------------------------
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py
b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
index 1054e688e76..1a044fdde2e 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
@@ -67,7 +67,7 @@ from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from airflow.configuration import AirflowConfigParser
- from airflow.executors.workloads import ExecuteTask
+ from airflow.providers.edge3.models.types import ExecuteTypeBody
from airflow.providers.edge3.worker_api.datamodels import EdgeJobFetched
logger = logging.getLogger(__name__)
@@ -422,7 +422,7 @@ class EdgeWorker:
return EdgeWorkerState.MAINTENANCE_MODE
return EdgeWorkerState.IDLE
- def _run_job_via_supervisor(self, workload: ExecuteTask, error_file_path:
Path) -> int:
+ def _run_job_via_supervisor(self, workload: ExecuteTypeBody,
error_file_path: Path) -> int:
"""Run a task by calling the supervisor directly (executes inside a
forked child process)."""
_reset_parent_signal_state()
@@ -465,7 +465,7 @@ class EdgeWorker:
error_file_path.write_text(traceback.format_exc())
return 1
- def _launch_job_subprocess(self, workload: ExecuteTask) ->
tuple[subprocess.Popen, Path]:
+ def _launch_job_subprocess(self, workload: ExecuteTypeBody) ->
tuple[subprocess.Popen, Path]:
"""Launch workload via a fresh Python interpreter
(subprocess.Popen)."""
env = os.environ.copy()
if self._execution_api_server_url:
@@ -500,11 +500,11 @@ class EdgeWorker:
logger.info(
"Launched task subprocess pid=%d for %s",
process.pid,
- workload.ti.id,
+ workload.display_name if AIRFLOW_V_3_3_PLUS else workload.ti.id,
)
return process, stderr_file_path
- def _launch_job_fork(self, workload: ExecuteTask) -> tuple[Process, Path]:
+ def _launch_job_fork(self, workload: ExecuteTypeBody) -> tuple[Process,
Path]:
"""Launch workload by forking the current process
(multiprocessing.Process)."""
# Improvement: Use frozen GC to prevent child process from copying
unnecessary memory
# See _spawn_workers_with_gc_freeze() in
airflow-core/src/airflow/executors/local_executor.py
@@ -515,10 +515,14 @@ class EdgeWorker:
kwargs={"workload": workload, "error_file_path": error_file_path},
)
process.start()
- logger.info("Launched task fork pid=%d for %s", process.pid,
workload.ti.id)
+ logger.info(
+ "Launched task fork pid=%d for %s",
+ process.pid,
+ workload.display_name if AIRFLOW_V_3_3_PLUS else workload.ti.id,
+ )
return process, error_file_path
- def _launch_job(self, edge_job: EdgeJobFetched, workload: ExecuteTask,
logfile: Path) -> Job:
+ def _launch_job(self, edge_job: EdgeJobFetched, workload: ExecuteTypeBody,
logfile: Path) -> Job:
"""
Launch a task process.
@@ -673,7 +677,7 @@ class EdgeWorker:
logger.info("Received job: %s", edge_job.identifier)
- workload: ExecuteTask = edge_job.command
+ workload: ExecuteTypeBody = edge_job.command
if TYPE_CHECKING:
assert workload.log_path # We need to assume this is defined in
here
logfile = Path(self.base_log_folder, workload.log_path)
diff --git
a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
index 99d79ffb519..9ac25874c3c 100644
--- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
+++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
@@ -33,6 +33,7 @@ from airflow.providers.edge3.models.db import EdgeDBManager,
check_db_manager_co
from airflow.providers.edge3.models.edge_job import EdgeJobModel
from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel,
EdgeWorkerState, reset_metrics
+from airflow.providers.edge3.models.types import is_callback_execute
from airflow.utils.db import DBLocks, create_global_lock
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
@@ -101,44 +102,73 @@ class EdgeExecutor(BaseExecutor):
session: Session,
) -> None:
"""Put new workload to queue. Airflow 3 entry point to execute a
task."""
- if not isinstance(workload, workloads.ExecuteTask):
- raise TypeError(f"Don't know how to queue workload of type
{type(workload).__name__}")
-
- task_instance = workload.ti
- key = task_instance.key
-
- # Check if job already exists with same dag_id, task_id, run_id,
map_index, try_number
- existing_job = session.scalars(
- select(EdgeJobModel).where(
- EdgeJobModel.dag_id == key.dag_id,
- EdgeJobModel.task_id == key.task_id,
- EdgeJobModel.run_id == key.run_id,
- EdgeJobModel.map_index == key.map_index,
- EdgeJobModel.try_number == key.try_number,
- )
- ).first()
-
- if existing_job:
- existing_job.state = TaskInstanceState.QUEUED
- existing_job.queue = task_instance.queue
- existing_job.concurrency_slots = task_instance.pool_slots
- existing_job.command = workload.model_dump_json()
- existing_job.team_name = self.team_name
- else:
- session.add(
- EdgeJobModel(
- dag_id=key.dag_id,
- task_id=key.task_id,
- run_id=key.run_id,
- map_index=key.map_index,
- try_number=key.try_number,
- state=TaskInstanceState.QUEUED,
- queue=task_instance.queue,
- concurrency_slots=task_instance.pool_slots,
- command=workload.model_dump_json(),
- team_name=self.team_name,
+ if is_callback_execute(workload):
+ from airflow.providers.edge3.models.types import
EXECUTE_CALLBACK_TAG
+
+ existing_job = session.scalars(
+ select(EdgeJobModel).where(
+ EdgeJobModel.dag_id == EXECUTE_CALLBACK_TAG,
+ EdgeJobModel.task_id == workload.callback.id,
+ EdgeJobModel.run_id ==
f"{EXECUTE_CALLBACK_TAG}-{workload.callback.id}",
)
- )
+ ).first()
+
+ if existing_job:
+ existing_job.state = TaskInstanceState.QUEUED
+ existing_job.command = workload.model_dump_json()
+ else:
+ session.add(
+ EdgeJobModel(
+ dag_id=EXECUTE_CALLBACK_TAG,
+ task_id=str(workload.callback.id),
+
run_id=f"{EXECUTE_CALLBACK_TAG}-{workload.callback.id}",
+ map_index=-1,
+ try_number=0,
+ queue=self.conf.get_mandatory_value("operators",
"default_queue"),
+ concurrency_slots=1,
+ state=TaskInstanceState.QUEUED,
+ command=workload.model_dump_json(),
+ team_name=self.team_name,
+ )
+ )
+ elif isinstance(workload, workloads.ExecuteTask):
+ task_instance = workload.ti
+ key = task_instance.key
+
+ # Check if job already exists with same dag_id, task_id, run_id,
map_index, try_number
+ existing_job = session.scalars(
+ select(EdgeJobModel).where(
+ EdgeJobModel.dag_id == key.dag_id,
+ EdgeJobModel.task_id == key.task_id,
+ EdgeJobModel.run_id == key.run_id,
+ EdgeJobModel.map_index == key.map_index,
+ EdgeJobModel.try_number == key.try_number,
+ )
+ ).first()
+
+ if existing_job:
+ existing_job.state = TaskInstanceState.QUEUED
+ existing_job.queue = task_instance.queue
+ existing_job.concurrency_slots = task_instance.pool_slots
+ existing_job.command = workload.model_dump_json()
+ existing_job.team_name = self.team_name
+ else:
+ session.add(
+ EdgeJobModel(
+ dag_id=key.dag_id,
+ task_id=key.task_id,
+ run_id=key.run_id,
+ map_index=key.map_index,
+ try_number=key.try_number,
+ state=TaskInstanceState.QUEUED,
+ queue=task_instance.queue,
+ concurrency_slots=task_instance.pool_slots,
+ command=workload.model_dump_json(),
+ team_name=self.team_name,
+ )
+ )
+ else:
+ raise TypeError(f"Don't know how to queue workload of type
{type(workload).__name__}")
def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
"""
diff --git a/providers/edge3/src/airflow/providers/edge3/models/types.py
b/providers/edge3/src/airflow/providers/edge3/models/types.py
new file mode 100644
index 00000000000..19cea39539d
--- /dev/null
+++ b/providers/edge3/src/airflow/providers/edge3/models/types.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, TypeAlias, TypeGuard
+
+from airflow.executors.workloads import ExecuteTask
+from airflow.providers.edge3.version_compat import AIRFLOW_V_3_3_PLUS
+
+if TYPE_CHECKING:
+ from airflow.executors import workloads
+ from airflow.executors.workloads import ExecuteCallback
+
+if not AIRFLOW_V_3_3_PLUS:
+ ExecuteTypeBody: TypeAlias = ExecuteTask
+else:
+ from airflow.executors.workloads import ExecutorWorkload
+
+ ExecuteTypeBody: TypeAlias = ExecutorWorkload # type:
ignore[no-redef,misc]
+
+
+def is_callback_execute(workload: workloads.All) -> TypeGuard[ExecuteCallback]:
+ if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads import ExecuteCallback
+
+ return isinstance(workload, ExecuteCallback)
+ return False
+
+
+# This is the key used to identify execute_callback jobs.
+# Changing this value may break compatibility with existing data in the
edge_job table.
+EXECUTE_CALLBACK_TAG = "ExecuteCallback"
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
index 7edaa1ce5d2..fde62612df8 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
@@ -22,9 +22,9 @@ from typing import Annotated
from fastapi import Path
from pydantic import BaseModel, Field
-from airflow.executors.workloads import ExecuteTask # noqa: TCH001
from airflow.providers.common.compat.sdk import TaskInstanceKey
from airflow.providers.edge3.models.edge_worker import EdgeWorkerState #
noqa: TCH001
+from airflow.providers.edge3.models.types import ExecuteTypeBody # noqa:
TCH001
class WorkerApiDocs:
@@ -69,7 +69,7 @@ class EdgeJobFetched(EdgeJobBase):
"""Job that is to be executed on the edge worker."""
command: Annotated[
- ExecuteTask,
+ ExecuteTypeBody,
Field(
title="Command",
description="Command line to use to execute the job in Airflow",
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
index 5b296c92f1d..c110e0ad07f 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
@@ -17,7 +17,7 @@
from __future__ import annotations
-from typing import Annotated
+from typing import TYPE_CHECKING, Annotated
from fastapi import Body, Depends, status
from sqlalchemy import select, update
@@ -28,6 +28,7 @@ from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_
from airflow.executors.workloads import ExecuteTask
from airflow.providers.common.compat.sdk import Stats, timezone
from airflow.providers.edge3.models.edge_job import EdgeJobModel
+from airflow.providers.edge3.version_compat import AIRFLOW_V_3_3_PLUS
from airflow.providers.edge3.worker_api.auth import
jwt_token_authorization_rest
from airflow.providers.edge3.worker_api.datamodels import (
EdgeJobFetched,
@@ -36,10 +37,20 @@ from airflow.providers.edge3.worker_api.datamodels import (
)
from airflow.utils.state import TaskInstanceState
+if TYPE_CHECKING:
+ from airflow.providers.edge3.models.types import ExecuteTypeBody
+
jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
-def parse_command(command: str) -> ExecuteTask:
+def parse_command(command: str, dag_id: str, run_id: str) -> ExecuteTypeBody:
+ if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads import ExecuteCallback
+ from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG
+
+ if dag_id == EXECUTE_CALLBACK_TAG and
run_id.startswith(EXECUTE_CALLBACK_TAG):
+ return ExecuteCallback.model_validate_json(command) # type:
ignore[return-value]
+
return ExecuteTask.model_validate_json(command)
@@ -94,7 +105,7 @@ def fetch(
run_id=job.run_id,
map_index=job.map_index,
try_number=job.try_number,
- command=parse_command(job.command),
+ command=parse_command(job.command, job.dag_id, job.run_id),
concurrency_slots=job.concurrency_slots,
)
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
index 6b006a92916..ae26fc79a22 100644
---
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
+++
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
@@ -964,6 +964,32 @@ components:
- name
title: BundleInfo
description: Schema for telling task which bundle to run with.
+ CallbackDTO:
+ properties:
+ id:
+ type: string
+ title: Id
+ fetch_method:
+ $ref: '#/components/schemas/CallbackFetchMethod'
+ data:
+ additionalProperties: true
+ type: object
+ title: Data
+ type: object
+ required:
+ - id
+ - fetch_method
+ - data
+ title: CallbackDTO
+ description: Schema for Callback with minimal required fields needed for
Executors
+ and Task SDK.
+ CallbackFetchMethod:
+ type: string
+ enum:
+ - dag_attribute
+ - import_path
+ title: CallbackFetchMethod
+ description: Methods used to fetch callback at runtime.
ConcurrencyRequest:
properties:
concurrency:
@@ -1000,9 +1026,16 @@ components:
title: Try Number
description: The number of attempt to execute this task.
command:
- $ref: '#/components/schemas/ExecuteTask'
+ oneOf:
+ - $ref: '#/components/schemas/ExecuteTask'
+ - $ref: '#/components/schemas/ExecuteCallback'
title: Command
description: Command line to use to execute the job in Airflow
+ discriminator:
+ propertyName: type
+ mapping:
+ ExecuteCallback: '#/components/schemas/ExecuteCallback'
+ ExecuteTask: '#/components/schemas/ExecuteTask'
concurrency_slots:
type: integer
title: Concurrency Slots
@@ -1035,6 +1068,38 @@ components:
- offline maintenance
title: EdgeWorkerState
description: Status of a Edge Worker instance.
+ ExecuteCallback:
+ properties:
+ token:
+ type: string
+ title: Token
+ dag_rel_path:
+ type: string
+ format: path
+ title: Dag Rel Path
+ bundle_info:
+ $ref: '#/components/schemas/BundleInfo'
+ log_path:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Log Path
+ callback:
+ $ref: '#/components/schemas/CallbackDTO'
+ type:
+ type: string
+ const: ExecuteCallback
+ title: Type
+ default: ExecuteCallback
+ type: object
+ required:
+ - token
+ - dag_rel_path
+ - bundle_info
+ - log_path
+ - callback
+ title: ExecuteCallback
+ description: Execute the given Callback.
ExecuteTask:
properties:
token:
diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py
b/providers/edge3/tests/unit/edge3/cli/test_worker.py
index 245c7431cd4..23a247d8a08 100644
--- a/providers/edge3/tests/unit/edge3/cli/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py
@@ -28,6 +28,7 @@ from datetime import datetime
from io import StringIO
from multiprocessing import Process
from pathlib import Path
+from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import call, patch
@@ -48,6 +49,7 @@ from airflow.providers.edge3.models.edge_worker import (
EdgeWorkerState,
EdgeWorkerVersionException,
)
+from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG
from airflow.providers.edge3.worker_api.datamodels import (
EdgeJobFetched,
WorkerRegistrationReturn,
@@ -58,6 +60,12 @@ from airflow.utils.state import TaskInstanceState
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS,
AIRFLOW_V_3_3_PLUS
+if TYPE_CHECKING:
+ from airflow.executors.workloads import ExecuteCallback
+
+if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads import ExecuteCallback
+
pytest.importorskip("pydantic", minversion="2.0.0")
pytestmark = [pytest.mark.asyncio]
@@ -80,6 +88,23 @@ MOCK_COMMAND = {
"dag_rel_path": "mock.py",
"log_path": "mock.log",
"bundle_info": {"name": "hello", "version": "abc"},
+ "type": "ExecuteTask",
+}
+
+MOCK_CALLBACK_COMMAND = {
+ "token": "mock",
+ "callback": {
+ "id": "12345678-1234-5678-1234-567812345678",
+ "fetch_method": "import_path",
+ "data": {
+ "path": "builtins.dict",
+ "kwargs": {"a": 1, "b": 2, "c": 3},
+ },
+ },
+ "dag_rel_path": "test.py",
+ "log_path": "test.log",
+ "bundle_info": {"name": "test_bundle", "version": "1.0"},
+ "type": "ExecuteCallback",
}
@@ -1150,3 +1175,25 @@ class TestSignalHandling:
with mock.patch("os.kill", side_effect=ProcessLookupError):
worker_with_one_job.shutdown_handler()
assert worker_with_one_job.drain is True
+
+
+class TestEdgeJobFetchedSerialization:
+ """Test that EdgeJobFetched serializes and deserializes with both
ExecuteTask and ExecuteCallback."""
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The tests should be
skipped for Airflow < 3.3")
+ def test_serialize_with_execute_callback(self):
+ fetched = EdgeJobFetched(
+ dag_id=EXECUTE_CALLBACK_TAG,
+ task_id="12345678-1234-5678-1234-567812345678",
+
run_id=f"{EXECUTE_CALLBACK_TAG}-12345678-1234-5678-1234-567812345678",
+ map_index=-1,
+ try_number=0,
+ concurrency_slots=1,
+ command=MOCK_CALLBACK_COMMAND, # type: ignore[arg-type]
+ )
+ serialized = fetched.model_dump_json()
+ deserialized = EdgeJobFetched(**json.loads(serialized))
+
+ assert deserialized.dag_id == EXECUTE_CALLBACK_TAG
+ assert deserialized.command.type == EXECUTE_CALLBACK_TAG
+ assert isinstance(deserialized.command, ExecuteCallback)
diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
index 840aadb4caf..3846826c0a6 100644
--- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
+++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
@@ -19,22 +19,30 @@ from __future__ import annotations
import logging
import os
from datetime import datetime, timedelta
+from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, patch
+from uuid import uuid4
import pytest
import time_machine
from sqlalchemy import delete, select
+from airflow.executors.workloads import BundleInfo, ExecuteTask
from airflow.providers.common.compat.sdk import Stats, TaskInstanceKey, conf,
timezone
from airflow.providers.edge3.executors.edge_executor import EdgeExecutor
from airflow.providers.edge3.models.edge_job import EdgeJobModel
from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel,
EdgeWorkerState
+from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG
from airflow.utils.session import create_session
from airflow.utils.state import TaskInstanceState
from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS,
AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads import CallbackFetchMethod,
ExecuteCallback, TaskInstanceDTO
+ from airflow.executors.workloads.callback import CallbackDTO
pytestmark = pytest.mark.db_test
@@ -559,3 +567,121 @@ class TestEdgeExecutorMultiTeam:
with create_session() as session:
remaining_jobs = session.scalars(select(EdgeJobModel)).all()
assert len(remaining_jobs) == 2
+
+
[email protected](not AIRFLOW_V_3_3_PLUS, reason="ExecuteTypeBody union
requires Airflow 3.3+")
+class TestQueueWorkload:
+ @pytest.fixture(autouse=True)
+ def setup(self):
+ with create_session() as session:
+ session.execute(delete(EdgeJobModel))
+ session.commit()
+
+ def _make_execute_task(self) -> ExecuteTask:
+ ti = TaskInstanceDTO(
+ id=uuid4(),
+ dag_version_id=uuid4(),
+ task_id="test_task",
+ dag_id="test_dag",
+ run_id="test_run",
+ try_number=1,
+ map_index=-1,
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
+ )
+ return ExecuteTask(
+ ti=ti,
+ dag_rel_path=Path("test_dag.py"),
+ token="test_token",
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ log_path="test.log",
+ )
+
+ def test_queue_workload_execute_task(self):
+ executor = EdgeExecutor()
+ workload = self._make_execute_task()
+
+ executor.queue_workload(workload)
+
+ with create_session() as session:
+ job = session.scalar(select(EdgeJobModel))
+ assert job is not None
+ assert job.dag_id == "test_dag"
+ assert job.task_id == "test_task"
+ assert job.run_id == "test_run"
+ assert job.state == TaskInstanceState.QUEUED
+ assert '"type":"ExecuteTask"' in job.command or '"type":
"ExecuteTask"' in job.command
+
+ def test_queue_workload_execute_task_existing_job(self):
+ executor = EdgeExecutor()
+ workload = self._make_execute_task()
+
+ executor.queue_workload(workload)
+ executor.queue_workload(workload)
+
+ with create_session() as session:
+ jobs = session.scalars(select(EdgeJobModel)).all()
+ assert len(jobs) == 1
+ assert jobs[0].state == TaskInstanceState.QUEUED
+
+ def test_queue_workload_execute_callback(self):
+ executor = EdgeExecutor()
+ id = str(uuid4())
+ callback_data = CallbackDTO(
+ id=id,
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={
+ "path": "builtins.dict",
+ "kwargs": {"a": 1, "b": 2, "c": 3},
+ },
+ )
+ workload = ExecuteCallback(
+ callback=callback_data,
+ dag_rel_path=Path("test.py"),
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ token="test_token",
+ log_path="test.log",
+ )
+
+ executor.queue_workload(workload)
+
+ with create_session() as session:
+ job = session.scalar(select(EdgeJobModel))
+ assert job is not None
+ assert job.dag_id == EXECUTE_CALLBACK_TAG
+ assert job.task_id == id
+ assert job.run_id == f"{EXECUTE_CALLBACK_TAG}-{id}"
+ assert job.state == TaskInstanceState.QUEUED
+ assert '"type":"ExecuteCallback"' in job.command or '"type":
"ExecuteCallback"' in job.command
+
+ def test_queue_workload_execute_callback_existing_job(self):
+ executor = EdgeExecutor()
+ callback_data = CallbackDTO(
+ id=str(uuid4()),
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={
+ "path": "builtins.dict",
+ "kwargs": {"a": 1, "b": 2, "c": 3},
+ },
+ )
+ workload = ExecuteCallback(
+ callback=callback_data,
+ dag_rel_path=Path("test.py"),
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ token="test_token",
+ log_path="test.log",
+ )
+
+ executor.queue_workload(workload)
+ executor.queue_workload(workload)
+
+ with create_session() as session:
+ jobs = session.scalars(select(EdgeJobModel)).all()
+ assert len(jobs) == 1
+ assert jobs[0].state == TaskInstanceState.QUEUED
+
+ def test_queue_workload_unknown_type_raises(self):
+ executor = EdgeExecutor()
+ with pytest.raises(TypeError, match="Don't know how to queue
workload"):
+ executor.queue_workload(MagicMock(spec=[]))
diff --git a/providers/edge3/tests/unit/edge3/models/test_types.py
b/providers/edge3/tests/unit/edge3/models/test_types.py
new file mode 100644
index 00000000000..7e21ce280d4
--- /dev/null
+++ b/providers/edge3/tests/unit/edge3/models/test_types.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from pathlib import Path
+from uuid import uuid4
+
+import pytest
+from pydantic import TypeAdapter
+
+from airflow.executors.workloads import BundleInfo, ExecuteTask
+from airflow.providers.edge3.models.types import ExecuteTypeBody,
is_callback_execute
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads import CallbackFetchMethod,
ExecuteCallback, TaskInstanceDTO
+ from airflow.executors.workloads.callback import CallbackDTO
+
+
+def _make_execute_task() -> ExecuteTask:
+ ti = TaskInstanceDTO(
+ id=uuid4(),
+ dag_version_id=uuid4(),
+ task_id="test_task",
+ dag_id="test_dag",
+ run_id="test_run",
+ try_number=1,
+ map_index=-1,
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
+ )
+ return ExecuteTask(
+ ti=ti,
+ dag_rel_path=Path("test_dag.py"),
+ token="test_token",
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ log_path="test.log",
+ )
+
+
+def _make_execute_callback() -> ExecuteCallback:
+ callback_data = CallbackDTO(
+ id=str(uuid4()),
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={
+ "path": "builtins.dict",
+ "kwargs": {"a": 1, "b": 2, "c": 3},
+ },
+ )
+ return ExecuteCallback(
+ callback=callback_data,
+ dag_rel_path=Path("test.py"),
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ token="test_token",
+ log_path="test.log",
+ )
+
+
[email protected](not AIRFLOW_V_3_3_PLUS, reason="ExecuteTypeBody union
requires Airflow 3.3+")
+class TestIsCallbackExecute:
+ def test_returns_false_for_execute_task(self):
+ workload = _make_execute_task()
+ assert is_callback_execute(workload) is False
+
+ def test_returns_true_for_execute_callback(self):
+ workload = _make_execute_callback()
+ assert is_callback_execute(workload) is True
+
+
[email protected](not AIRFLOW_V_3_3_PLUS, reason="ExecuteTypeBody union
requires Airflow 3.3+")
+class TestExecuteTypeBody:
+ def setup_method(self):
+ self.adapter: TypeAdapter = TypeAdapter(ExecuteTypeBody)
+
+ def test_validate_execute_task_json(self):
+ workload = _make_execute_task()
+ json_str = workload.model_dump_json()
+
+ result = self.adapter.validate_json(json_str)
+
+ assert isinstance(result, ExecuteTask)
+ assert result.ti.dag_id == "test_dag"
+
+ def test_validate_execute_callback_json(self):
+ workload = _make_execute_callback()
+ json_str = workload.model_dump_json()
+
+ result = self.adapter.validate_json(json_str)
+
+ assert isinstance(result, ExecuteCallback)
+ assert result.callback.fetch_method == CallbackFetchMethod.IMPORT_PATH
+
+ def test_validate_execute_task_dict(self):
+ workload = _make_execute_task()
+ data = workload.model_dump()
+
+ result = self.adapter.validate_python(data)
+
+ assert isinstance(result, ExecuteTask)
+
+ def test_validate_execute_callback_dict(self):
+ workload = _make_execute_callback()
+ data = workload.model_dump()
+
+ result = self.adapter.validate_python(data)
+
+ assert isinstance(result, ExecuteCallback)
+
+ def test_roundtrip_execute_task(self):
+ original = _make_execute_task()
+ json_str = self.adapter.dump_json(original)
+ restored = self.adapter.validate_json(json_str)
+
+ assert isinstance(restored, ExecuteTask)
+ assert restored.ti.task_id == original.ti.task_id
+ assert restored.ti.dag_id == original.ti.dag_id
+
+ def test_roundtrip_execute_callback(self):
+ original = _make_execute_callback()
+ json_str = self.adapter.dump_json(original)
+ restored = self.adapter.validate_json(json_str)
+
+ assert isinstance(restored, ExecuteCallback)
+ assert restored.callback.id == original.callback.id
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
index 22c63d12cea..c120e86f088 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
@@ -17,22 +17,34 @@
from __future__ import annotations
import json
+from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import patch
+from uuid import uuid4
import pytest
from sqlalchemy import delete, select
+from airflow.executors.workloads import BundleInfo, ExecuteTask
from airflow.providers.common.compat.sdk import Stats
from airflow.providers.edge3.models.edge_job import EdgeJobModel
+from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG
from airflow.providers.edge3.worker_api.datamodels import WorkerQueuesBody
-from airflow.providers.edge3.worker_api.routes.jobs import fetch, state
+from airflow.providers.edge3.worker_api.routes.jobs import fetch,
parse_command, state
from airflow.utils.session import create_session
from airflow.utils.state import TaskInstanceState
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+
if TYPE_CHECKING:
from sqlalchemy.orm import Session
+ from airflow.executors.workloads import ExecuteCallback
+
+if AIRFLOW_V_3_3_PLUS:
+ from airflow.executors.workloads import CallbackFetchMethod,
ExecuteCallback, TaskInstanceDTO
+ from airflow.executors.workloads.callback import CallbackDTO
+
pytestmark = pytest.mark.db_test
DAG_ID = "my_dag"
@@ -198,3 +210,76 @@ class TestJobsApiRoutes:
assert result3 is None
fetched_dag_ids = {result1.dag_id, result2.dag_id}
assert fetched_dag_ids == {"dag_a", "dag_b"}
+
+
[email protected](not AIRFLOW_V_3_3_PLUS, reason="The tests should be
skipped for Airflow < 3.3")
+class TestParseCommand:
+ def _make_execute_task(self) -> ExecuteTask:
+ ti = TaskInstanceDTO(
+ id=uuid4(),
+ dag_version_id=uuid4(),
+ task_id="test_task",
+ dag_id="test_dag",
+ run_id="test_run",
+ try_number=1,
+ map_index=-1,
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
+ )
+ return ExecuteTask(
+ ti=ti,
+ dag_rel_path=Path("test_dag.py"),
+ token="test_token",
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ log_path="test.log",
+ )
+
+ def _make_execute_callback(self) -> ExecuteCallback:
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={
+ "path": "builtins.dict",
+ "kwargs": {"a": 1, "b": 2, "c": 3},
+ },
+ )
+ return ExecuteCallback(
+ callback=callback_data,
+ dag_rel_path=Path("test.py"),
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ token="test_token",
+ log_path="test.log",
+ )
+
+ def test_parse_command_execute_task(self):
+ workload = self._make_execute_task()
+ command_json = workload.model_dump_json()
+
+ result = parse_command(command_json, dag_id="test_dag",
run_id="test_run")
+
+ assert isinstance(result, ExecuteTask)
+ assert result.ti.dag_id == "test_dag"
+ assert result.ti.task_id == "test_task"
+
+ def test_parse_command_execute_callback(self):
+ workload = self._make_execute_callback()
+ command_json = workload.model_dump_json()
+
+ dag_id = EXECUTE_CALLBACK_TAG
+ run_id = f"{EXECUTE_CALLBACK_TAG}-{workload.callback.key}"
+
+ result = parse_command(command_json, dag_id=dag_id, run_id=run_id)
+
+ assert isinstance(result, ExecuteCallback)
+ assert result.callback.id == "12345678-1234-5678-1234-567812345678"
+ assert result.callback.fetch_method == CallbackFetchMethod.IMPORT_PATH
+
+ def test_parse_command_non_callback_dag_id_returns_execute_task(self):
+ """Even if run_id starts with ExecuteCallback, dag_id must also
match."""
+ workload = self._make_execute_task()
+ command_json = workload.model_dump_json()
+
+ result = parse_command(command_json, dag_id="some_dag",
run_id=f"{EXECUTE_CALLBACK_TAG}-something")
+
+ assert isinstance(result, ExecuteTask)