This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8607e9465b6 Fix edge executor to support handling execute callback workload (#67679)
8607e9465b6 is described below

commit 8607e9465b64e213855f3f9f4a9351e06914f975
Author: Jeongwoo Do <[email protected]>
AuthorDate: Tue Jun 2 02:27:57 2026 +0900

    Fix edge executor to support handling execute callback workload (#67679)
    
    * Fix edge executor to support handling execute callback workload
    
    * fix logic
    
    * fix logic
---
 providers/edge3/docs/edge_executor.rst             |  33 +++++
 .../src/airflow/providers/edge3/cli/worker.py      |  20 +--
 .../providers/edge3/executors/edge_executor.py     | 104 +++++++++------
 .../src/airflow/providers/edge3/models/types.py    |  46 +++++++
 .../providers/edge3/worker_api/datamodels.py       |   4 +-
 .../providers/edge3/worker_api/routes/jobs.py      |  17 ++-
 .../edge3/worker_api/v2-edge-generated.yaml        |  67 +++++++++-
 .../edge3/tests/unit/edge3/cli/test_worker.py      |  47 +++++++
 .../unit/edge3/executors/test_edge_executor.py     | 128 ++++++++++++++++++-
 .../edge3/tests/unit/edge3/models/test_types.py    | 140 +++++++++++++++++++++
 .../unit/edge3/worker_api/routes/test_jobs.py      |  87 ++++++++++++-
 11 files changed, 640 insertions(+), 53 deletions(-)

diff --git a/providers/edge3/docs/edge_executor.rst 
b/providers/edge3/docs/edge_executor.rst
index dd5041375c1..4452b869589 100644
--- a/providers/edge3/docs/edge_executor.rst
+++ b/providers/edge3/docs/edge_executor.rst
@@ -169,6 +169,39 @@ Here is an example setting pool_slots for a task:
 
         task_with_template()
 
+
+.. _edge_executor:execute_callback:
+
+Support ExecuteCallback in Worker
+---------------------------------
+
+In addition to executing tasks, the EdgeExecutor can also dispatch executor-level
+callbacks (``ExecuteCallback`` workloads, e.g. deadline callbacks) to edge workers.
+When the scheduler hands an ``ExecuteCallback`` to ``EdgeExecutor.queue_workload``,
+it is stored in the same job table (``EdgeJobModel``) that is used for task
+workloads. Callback jobs are always queued on the ``[operators] default_queue``
+queue with one concurrency slot, so at least one edge worker must serve that queue.
+
+Callback jobs share the ``EdgeJobModel`` table with task jobs. They are
+distinguished by reserved values in the identifier columns:
+
+- ``dag_id`` is set to the constant tag ``ExecuteCallback``.
+- ``task_id`` is set to the callback key (the callback ID).
+- ``run_id`` is set to ``ExecuteCallback-<callback_key>``.
+- ``map_index`` is fixed to ``-1`` and ``try_number`` to ``0``.
+
+When the worker fetches such a job through the worker API, the command payload is
+deserialized back into an ``ExecuteCallback`` workload (instead of an
+``ExecuteTask``) based on these identifiers. On Airflow 3.3+, the worker executes
+both task and callback workloads via ``BaseExecutor.run_workload`` (or the
+``airflow.sdk.execution_time.execute_workload`` entrypoint when using the subprocess path).
+
+.. note::
+
+    This feature is only active on Airflow 3.3 or newer. On earlier Airflow versions
+    the EdgeExecutor only handles ``ExecuteTask`` workloads and any
+    ``ExecuteCallback`` will be rejected with a ``TypeError``.
+
 Current Limitations Edge Executor
 ---------------------------------
 
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py 
b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
index 1054e688e76..1a044fdde2e 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
@@ -67,7 +67,7 @@ from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.configuration import AirflowConfigParser
-    from airflow.executors.workloads import ExecuteTask
+    from airflow.providers.edge3.models.types import ExecuteTypeBody
     from airflow.providers.edge3.worker_api.datamodels import EdgeJobFetched
 
 logger = logging.getLogger(__name__)
@@ -422,7 +422,7 @@ class EdgeWorker:
             return EdgeWorkerState.MAINTENANCE_MODE
         return EdgeWorkerState.IDLE
 
-    def _run_job_via_supervisor(self, workload: ExecuteTask, error_file_path: 
Path) -> int:
+    def _run_job_via_supervisor(self, workload: ExecuteTypeBody, 
error_file_path: Path) -> int:
         """Run a task by calling the supervisor directly (executes inside a 
forked child process)."""
         _reset_parent_signal_state()
 
@@ -465,7 +465,7 @@ class EdgeWorker:
                 error_file_path.write_text(traceback.format_exc())
             return 1
 
-    def _launch_job_subprocess(self, workload: ExecuteTask) -> 
tuple[subprocess.Popen, Path]:
+    def _launch_job_subprocess(self, workload: ExecuteTypeBody) -> 
tuple[subprocess.Popen, Path]:
         """Launch workload via a fresh Python interpreter 
(subprocess.Popen)."""
         env = os.environ.copy()
         if self._execution_api_server_url:
@@ -500,11 +500,11 @@ class EdgeWorker:
         logger.info(
             "Launched task subprocess pid=%d for %s",
             process.pid,
-            workload.ti.id,
+            workload.display_name if AIRFLOW_V_3_3_PLUS else workload.ti.id,
         )
         return process, stderr_file_path
 
-    def _launch_job_fork(self, workload: ExecuteTask) -> tuple[Process, Path]:
+    def _launch_job_fork(self, workload: ExecuteTypeBody) -> tuple[Process, 
Path]:
         """Launch workload by forking the current process 
(multiprocessing.Process)."""
         # Improvement: Use frozen GC to prevent child process from copying 
unnecessary memory
         # See _spawn_workers_with_gc_freeze() in 
airflow-core/src/airflow/executors/local_executor.py
@@ -515,10 +515,14 @@ class EdgeWorker:
             kwargs={"workload": workload, "error_file_path": error_file_path},
         )
         process.start()
-        logger.info("Launched task fork pid=%d for %s", process.pid, 
workload.ti.id)
+        logger.info(
+            "Launched task fork pid=%d for %s",
+            process.pid,
+            workload.display_name if AIRFLOW_V_3_3_PLUS else workload.ti.id,
+        )
         return process, error_file_path
 
-    def _launch_job(self, edge_job: EdgeJobFetched, workload: ExecuteTask, 
logfile: Path) -> Job:
+    def _launch_job(self, edge_job: EdgeJobFetched, workload: ExecuteTypeBody, 
logfile: Path) -> Job:
         """
         Launch a task process.
 
@@ -673,7 +677,7 @@ class EdgeWorker:
 
         logger.info("Received job: %s", edge_job.identifier)
 
-        workload: ExecuteTask = edge_job.command
+        workload: ExecuteTypeBody = edge_job.command
         if TYPE_CHECKING:
             assert workload.log_path  # We need to assume this is defined in 
here
         logfile = Path(self.base_log_folder, workload.log_path)
diff --git 
a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py 
b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
index 99d79ffb519..9ac25874c3c 100644
--- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
+++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
@@ -33,6 +33,7 @@ from airflow.providers.edge3.models.db import EdgeDBManager, 
check_db_manager_co
 from airflow.providers.edge3.models.edge_job import EdgeJobModel
 from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
 from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, 
EdgeWorkerState, reset_metrics
+from airflow.providers.edge3.models.types import is_callback_execute
 from airflow.utils.db import DBLocks, create_global_lock
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import TaskInstanceState
@@ -101,44 +102,73 @@ class EdgeExecutor(BaseExecutor):
         session: Session,
     ) -> None:
         """Put new workload to queue. Airflow 3 entry point to execute a 
task."""
-        if not isinstance(workload, workloads.ExecuteTask):
-            raise TypeError(f"Don't know how to queue workload of type 
{type(workload).__name__}")
-
-        task_instance = workload.ti
-        key = task_instance.key
-
-        # Check if job already exists with same dag_id, task_id, run_id, 
map_index, try_number
-        existing_job = session.scalars(
-            select(EdgeJobModel).where(
-                EdgeJobModel.dag_id == key.dag_id,
-                EdgeJobModel.task_id == key.task_id,
-                EdgeJobModel.run_id == key.run_id,
-                EdgeJobModel.map_index == key.map_index,
-                EdgeJobModel.try_number == key.try_number,
-            )
-        ).first()
-
-        if existing_job:
-            existing_job.state = TaskInstanceState.QUEUED
-            existing_job.queue = task_instance.queue
-            existing_job.concurrency_slots = task_instance.pool_slots
-            existing_job.command = workload.model_dump_json()
-            existing_job.team_name = self.team_name
-        else:
-            session.add(
-                EdgeJobModel(
-                    dag_id=key.dag_id,
-                    task_id=key.task_id,
-                    run_id=key.run_id,
-                    map_index=key.map_index,
-                    try_number=key.try_number,
-                    state=TaskInstanceState.QUEUED,
-                    queue=task_instance.queue,
-                    concurrency_slots=task_instance.pool_slots,
-                    command=workload.model_dump_json(),
-                    team_name=self.team_name,
+        if is_callback_execute(workload):
+            from airflow.providers.edge3.models.types import 
EXECUTE_CALLBACK_TAG
+
+            existing_job = session.scalars(
+                select(EdgeJobModel).where(
+                    EdgeJobModel.dag_id == EXECUTE_CALLBACK_TAG,
+                    EdgeJobModel.task_id == workload.callback.id,
+                    EdgeJobModel.run_id == 
f"{EXECUTE_CALLBACK_TAG}-{workload.callback.id}",
                 )
-            )
+            ).first()
+
+            if existing_job:
+                existing_job.state = TaskInstanceState.QUEUED
+                existing_job.command = workload.model_dump_json()
+            else:
+                session.add(
+                    EdgeJobModel(
+                        dag_id=EXECUTE_CALLBACK_TAG,
+                        task_id=str(workload.callback.id),
+                        
run_id=f"{EXECUTE_CALLBACK_TAG}-{workload.callback.id}",
+                        map_index=-1,
+                        try_number=0,
+                        queue=self.conf.get_mandatory_value("operators", 
"default_queue"),
+                        concurrency_slots=1,
+                        state=TaskInstanceState.QUEUED,
+                        command=workload.model_dump_json(),
+                        team_name=self.team_name,
+                    )
+                )
+        elif isinstance(workload, workloads.ExecuteTask):
+            task_instance = workload.ti
+            key = task_instance.key
+
+            # Check if job already exists with same dag_id, task_id, run_id, 
map_index, try_number
+            existing_job = session.scalars(
+                select(EdgeJobModel).where(
+                    EdgeJobModel.dag_id == key.dag_id,
+                    EdgeJobModel.task_id == key.task_id,
+                    EdgeJobModel.run_id == key.run_id,
+                    EdgeJobModel.map_index == key.map_index,
+                    EdgeJobModel.try_number == key.try_number,
+                )
+            ).first()
+
+            if existing_job:
+                existing_job.state = TaskInstanceState.QUEUED
+                existing_job.queue = task_instance.queue
+                existing_job.concurrency_slots = task_instance.pool_slots
+                existing_job.command = workload.model_dump_json()
+                existing_job.team_name = self.team_name
+            else:
+                session.add(
+                    EdgeJobModel(
+                        dag_id=key.dag_id,
+                        task_id=key.task_id,
+                        run_id=key.run_id,
+                        map_index=key.map_index,
+                        try_number=key.try_number,
+                        state=TaskInstanceState.QUEUED,
+                        queue=task_instance.queue,
+                        concurrency_slots=task_instance.pool_slots,
+                        command=workload.model_dump_json(),
+                        team_name=self.team_name,
+                    )
+                )
+        else:
+            raise TypeError(f"Don't know how to queue workload of type 
{type(workload).__name__}")
 
     def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
         """
diff --git a/providers/edge3/src/airflow/providers/edge3/models/types.py 
b/providers/edge3/src/airflow/providers/edge3/models/types.py
new file mode 100644
index 00000000000..19cea39539d
--- /dev/null
+++ b/providers/edge3/src/airflow/providers/edge3/models/types.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, TypeAlias, TypeGuard
+
+from airflow.executors.workloads import ExecuteTask
+from airflow.providers.edge3.version_compat import AIRFLOW_V_3_3_PLUS
+
+if TYPE_CHECKING:
+    from airflow.executors import workloads
+    from airflow.executors.workloads import ExecuteCallback
+
+if not AIRFLOW_V_3_3_PLUS:
+    ExecuteTypeBody: TypeAlias = ExecuteTask
+else:
+    from airflow.executors.workloads import ExecutorWorkload
+
+    ExecuteTypeBody: TypeAlias = ExecutorWorkload  # type: 
ignore[no-redef,misc]
+
+
+def is_callback_execute(workload: workloads.All) -> TypeGuard[ExecuteCallback]:
+    if AIRFLOW_V_3_3_PLUS:
+        from airflow.executors.workloads import ExecuteCallback
+
+        return isinstance(workload, ExecuteCallback)
+    return False
+
+
+# This is the key used to identify execute_callback jobs.
+# Changing this value may break compatibility with existing data in the 
edge_job table.
+EXECUTE_CALLBACK_TAG = "ExecuteCallback"
diff --git 
a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py 
b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
index 7edaa1ce5d2..fde62612df8 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
@@ -22,9 +22,9 @@ from typing import Annotated
 from fastapi import Path
 from pydantic import BaseModel, Field
 
-from airflow.executors.workloads import ExecuteTask  # noqa: TCH001
 from airflow.providers.common.compat.sdk import TaskInstanceKey
 from airflow.providers.edge3.models.edge_worker import EdgeWorkerState  # 
noqa: TCH001
+from airflow.providers.edge3.models.types import ExecuteTypeBody  # noqa: 
TCH001
 
 
 class WorkerApiDocs:
@@ -69,7 +69,7 @@ class EdgeJobFetched(EdgeJobBase):
     """Job that is to be executed on the edge worker."""
 
     command: Annotated[
-        ExecuteTask,
+        ExecuteTypeBody,
         Field(
             title="Command",
             description="Command line to use to execute the job in Airflow",
diff --git 
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py 
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
index 5b296c92f1d..c110e0ad07f 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from typing import Annotated
+from typing import TYPE_CHECKING, Annotated
 
 from fastapi import Body, Depends, status
 from sqlalchemy import select, update
@@ -28,6 +28,7 @@ from airflow.api_fastapi.core_api.openapi.exceptions import 
create_openapi_http_
 from airflow.executors.workloads import ExecuteTask
 from airflow.providers.common.compat.sdk import Stats, timezone
 from airflow.providers.edge3.models.edge_job import EdgeJobModel
+from airflow.providers.edge3.version_compat import AIRFLOW_V_3_3_PLUS
 from airflow.providers.edge3.worker_api.auth import 
jwt_token_authorization_rest
 from airflow.providers.edge3.worker_api.datamodels import (
     EdgeJobFetched,
@@ -36,10 +37,20 @@ from airflow.providers.edge3.worker_api.datamodels import (
 )
 from airflow.utils.state import TaskInstanceState
 
+if TYPE_CHECKING:
+    from airflow.providers.edge3.models.types import ExecuteTypeBody
+
 jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
 
 
-def parse_command(command: str) -> ExecuteTask:
+def parse_command(command: str, dag_id: str, run_id: str) -> ExecuteTypeBody:
+    if AIRFLOW_V_3_3_PLUS:
+        from airflow.executors.workloads import ExecuteCallback
+        from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG
+
+        if dag_id == EXECUTE_CALLBACK_TAG and 
run_id.startswith(EXECUTE_CALLBACK_TAG):
+            return ExecuteCallback.model_validate_json(command)  # type: 
ignore[return-value]
+
     return ExecuteTask.model_validate_json(command)
 
 
@@ -94,7 +105,7 @@ def fetch(
         run_id=job.run_id,
         map_index=job.map_index,
         try_number=job.try_number,
-        command=parse_command(job.command),
+        command=parse_command(job.command, job.dag_id, job.run_id),
         concurrency_slots=job.concurrency_slots,
     )
 
diff --git 
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml 
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
index 6b006a92916..ae26fc79a22 100644
--- 
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
+++ 
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
@@ -964,6 +964,32 @@ components:
       - name
       title: BundleInfo
       description: Schema for telling task which bundle to run with.
+    CallbackDTO:
+      properties:
+        id:
+          type: string
+          title: Id
+        fetch_method:
+          $ref: '#/components/schemas/CallbackFetchMethod'
+        data:
+          additionalProperties: true
+          type: object
+          title: Data
+      type: object
+      required:
+      - id
+      - fetch_method
+      - data
+      title: CallbackDTO
+      description: Schema for Callback with minimal required fields needed for 
Executors
+        and Task SDK.
+    CallbackFetchMethod:
+      type: string
+      enum:
+      - dag_attribute
+      - import_path
+      title: CallbackFetchMethod
+      description: Methods used to fetch callback at runtime.
     ConcurrencyRequest:
       properties:
         concurrency:
@@ -1000,9 +1026,16 @@ components:
           title: Try Number
           description: The number of attempt to execute this task.
         command:
-          $ref: '#/components/schemas/ExecuteTask'
+          oneOf:
+          - $ref: '#/components/schemas/ExecuteTask'
+          - $ref: '#/components/schemas/ExecuteCallback'
           title: Command
           description: Command line to use to execute the job in Airflow
+          discriminator:
+            propertyName: type
+            mapping:
+              ExecuteCallback: '#/components/schemas/ExecuteCallback'
+              ExecuteTask: '#/components/schemas/ExecuteTask'
         concurrency_slots:
           type: integer
           title: Concurrency Slots
@@ -1035,6 +1068,38 @@ components:
       - offline maintenance
       title: EdgeWorkerState
       description: Status of a Edge Worker instance.
+    ExecuteCallback:
+      properties:
+        token:
+          type: string
+          title: Token
+        dag_rel_path:
+          type: string
+          format: path
+          title: Dag Rel Path
+        bundle_info:
+          $ref: '#/components/schemas/BundleInfo'
+        log_path:
+          anyOf:
+          - type: string
+          - type: 'null'
+          title: Log Path
+        callback:
+          $ref: '#/components/schemas/CallbackDTO'
+        type:
+          type: string
+          const: ExecuteCallback
+          title: Type
+          default: ExecuteCallback
+      type: object
+      required:
+      - token
+      - dag_rel_path
+      - bundle_info
+      - log_path
+      - callback
+      title: ExecuteCallback
+      description: Execute the given Callback.
     ExecuteTask:
       properties:
         token:
diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py 
b/providers/edge3/tests/unit/edge3/cli/test_worker.py
index 245c7431cd4..23a247d8a08 100644
--- a/providers/edge3/tests/unit/edge3/cli/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py
@@ -28,6 +28,7 @@ from datetime import datetime
 from io import StringIO
 from multiprocessing import Process
 from pathlib import Path
+from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import call, patch
 
@@ -48,6 +49,7 @@ from airflow.providers.edge3.models.edge_worker import (
     EdgeWorkerState,
     EdgeWorkerVersionException,
 )
+from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG
 from airflow.providers.edge3.worker_api.datamodels import (
     EdgeJobFetched,
     WorkerRegistrationReturn,
@@ -58,6 +60,12 @@ from airflow.utils.state import TaskInstanceState
 from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS, 
AIRFLOW_V_3_3_PLUS
 
+if TYPE_CHECKING:
+    from airflow.executors.workloads import ExecuteCallback
+
+if AIRFLOW_V_3_3_PLUS:
+    from airflow.executors.workloads import ExecuteCallback
+
 pytest.importorskip("pydantic", minversion="2.0.0")
 pytestmark = [pytest.mark.asyncio]
 
@@ -80,6 +88,23 @@ MOCK_COMMAND = {
     "dag_rel_path": "mock.py",
     "log_path": "mock.log",
     "bundle_info": {"name": "hello", "version": "abc"},
+    "type": "ExecuteTask",
+}
+
+MOCK_CALLBACK_COMMAND = {
+    "token": "mock",
+    "callback": {
+        "id": "12345678-1234-5678-1234-567812345678",
+        "fetch_method": "import_path",
+        "data": {
+            "path": "builtins.dict",
+            "kwargs": {"a": 1, "b": 2, "c": 3},
+        },
+    },
+    "dag_rel_path": "test.py",
+    "log_path": "test.log",
+    "bundle_info": {"name": "test_bundle", "version": "1.0"},
+    "type": "ExecuteCallback",
 }
 
 
@@ -1150,3 +1175,25 @@ class TestSignalHandling:
         with mock.patch("os.kill", side_effect=ProcessLookupError):
             worker_with_one_job.shutdown_handler()
         assert worker_with_one_job.drain is True
+
+
+class TestEdgeJobFetchedSerialization:
+    """Test that EdgeJobFetched serializes and deserializes with both 
ExecuteTask and ExecuteCallback."""
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The tests should be 
skipped for Airflow < 3.3")
+    def test_serialize_with_execute_callback(self):
+        fetched = EdgeJobFetched(
+            dag_id=EXECUTE_CALLBACK_TAG,
+            task_id="12345678-1234-5678-1234-567812345678",
+            
run_id=f"{EXECUTE_CALLBACK_TAG}-12345678-1234-5678-1234-567812345678",
+            map_index=-1,
+            try_number=0,
+            concurrency_slots=1,
+            command=MOCK_CALLBACK_COMMAND,  # type: ignore[arg-type]
+        )
+        serialized = fetched.model_dump_json()
+        deserialized = EdgeJobFetched(**json.loads(serialized))
+
+        assert deserialized.dag_id == EXECUTE_CALLBACK_TAG
+        assert deserialized.command.type == EXECUTE_CALLBACK_TAG
+        assert isinstance(deserialized.command, ExecuteCallback)
diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py 
b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
index 840aadb4caf..3846826c0a6 100644
--- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
+++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
@@ -19,22 +19,30 @@ from __future__ import annotations
 import logging
 import os
 from datetime import datetime, timedelta
+from pathlib import Path
 from unittest import mock
 from unittest.mock import MagicMock, patch
+from uuid import uuid4
 
 import pytest
 import time_machine
 from sqlalchemy import delete, select
 
+from airflow.executors.workloads import BundleInfo, ExecuteTask
 from airflow.providers.common.compat.sdk import Stats, TaskInstanceKey, conf, 
timezone
 from airflow.providers.edge3.executors.edge_executor import EdgeExecutor
 from airflow.providers.edge3.models.edge_job import EdgeJobModel
 from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, 
EdgeWorkerState
+from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG
 from airflow.utils.session import create_session
 from airflow.utils.state import TaskInstanceState
 
 from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS, 
AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+    from airflow.executors.workloads import CallbackFetchMethod, 
ExecuteCallback, TaskInstanceDTO
+    from airflow.executors.workloads.callback import CallbackDTO
 
 pytestmark = pytest.mark.db_test
 
@@ -559,3 +567,121 @@ class TestEdgeExecutorMultiTeam:
         with create_session() as session:
             remaining_jobs = session.scalars(select(EdgeJobModel)).all()
             assert len(remaining_jobs) == 2
+
+
[email protected](not AIRFLOW_V_3_3_PLUS, reason="ExecuteTypeBody union 
requires Airflow 3.3+")
+class TestQueueWorkload:
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        with create_session() as session:
+            session.execute(delete(EdgeJobModel))
+            session.commit()
+
+    def _make_execute_task(self) -> ExecuteTask:
+        ti = TaskInstanceDTO(
+            id=uuid4(),
+            dag_version_id=uuid4(),
+            task_id="test_task",
+            dag_id="test_dag",
+            run_id="test_run",
+            try_number=1,
+            map_index=-1,
+            pool_slots=1,
+            queue="default",
+            priority_weight=1,
+        )
+        return ExecuteTask(
+            ti=ti,
+            dag_rel_path=Path("test_dag.py"),
+            token="test_token",
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            log_path="test.log",
+        )
+
+    def test_queue_workload_execute_task(self):
+        executor = EdgeExecutor()
+        workload = self._make_execute_task()
+
+        executor.queue_workload(workload)
+
+        with create_session() as session:
+            job = session.scalar(select(EdgeJobModel))
+            assert job is not None
+            assert job.dag_id == "test_dag"
+            assert job.task_id == "test_task"
+            assert job.run_id == "test_run"
+            assert job.state == TaskInstanceState.QUEUED
+            assert '"type":"ExecuteTask"' in job.command or '"type": 
"ExecuteTask"' in job.command
+
+    def test_queue_workload_execute_task_existing_job(self):
+        executor = EdgeExecutor()
+        workload = self._make_execute_task()
+
+        executor.queue_workload(workload)
+        executor.queue_workload(workload)
+
+        with create_session() as session:
+            jobs = session.scalars(select(EdgeJobModel)).all()
+            assert len(jobs) == 1
+            assert jobs[0].state == TaskInstanceState.QUEUED
+
+    def test_queue_workload_execute_callback(self):
+        executor = EdgeExecutor()
+        id = str(uuid4())
+        callback_data = CallbackDTO(
+            id=id,
+            fetch_method=CallbackFetchMethod.IMPORT_PATH,
+            data={
+                "path": "builtins.dict",
+                "kwargs": {"a": 1, "b": 2, "c": 3},
+            },
+        )
+        workload = ExecuteCallback(
+            callback=callback_data,
+            dag_rel_path=Path("test.py"),
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            token="test_token",
+            log_path="test.log",
+        )
+
+        executor.queue_workload(workload)
+
+        with create_session() as session:
+            job = session.scalar(select(EdgeJobModel))
+            assert job is not None
+            assert job.dag_id == EXECUTE_CALLBACK_TAG
+            assert job.task_id == id
+            assert job.run_id == f"{EXECUTE_CALLBACK_TAG}-{id}"
+            assert job.state == TaskInstanceState.QUEUED
+            assert '"type":"ExecuteCallback"' in job.command or '"type": 
"ExecuteCallback"' in job.command
+
+    def test_queue_workload_execute_callback_existing_job(self):
+        executor = EdgeExecutor()
+        callback_data = CallbackDTO(
+            id=str(uuid4()),
+            fetch_method=CallbackFetchMethod.IMPORT_PATH,
+            data={
+                "path": "builtins.dict",
+                "kwargs": {"a": 1, "b": 2, "c": 3},
+            },
+        )
+        workload = ExecuteCallback(
+            callback=callback_data,
+            dag_rel_path=Path("test.py"),
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            token="test_token",
+            log_path="test.log",
+        )
+
+        executor.queue_workload(workload)
+        executor.queue_workload(workload)
+
+        with create_session() as session:
+            jobs = session.scalars(select(EdgeJobModel)).all()
+            assert len(jobs) == 1
+            assert jobs[0].state == TaskInstanceState.QUEUED
+
+    def test_queue_workload_unknown_type_raises(self):
+        executor = EdgeExecutor()
+        with pytest.raises(TypeError, match="Don't know how to queue 
workload"):
+            executor.queue_workload(MagicMock(spec=[]))
diff --git a/providers/edge3/tests/unit/edge3/models/test_types.py 
b/providers/edge3/tests/unit/edge3/models/test_types.py
new file mode 100644
index 00000000000..7e21ce280d4
--- /dev/null
+++ b/providers/edge3/tests/unit/edge3/models/test_types.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from pathlib import Path
+from uuid import uuid4
+
+import pytest
+from pydantic import TypeAdapter
+
+from airflow.executors.workloads import BundleInfo, ExecuteTask
+from airflow.providers.edge3.models.types import ExecuteTypeBody, 
is_callback_execute
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+
+if AIRFLOW_V_3_3_PLUS:
+    from airflow.executors.workloads import CallbackFetchMethod, 
ExecuteCallback, TaskInstanceDTO
+    from airflow.executors.workloads.callback import CallbackDTO
+
+
+def _make_execute_task() -> ExecuteTask:  # Airflow 3.3+ only: TaskInstanceDTO import is version-gated
+    ti = TaskInstanceDTO(
+        id=uuid4(),
+        dag_version_id=uuid4(),
+        task_id="test_task",
+        dag_id="test_dag",
+        run_id="test_run",
+        try_number=1,
+        map_index=-1,
+        pool_slots=1,
+        queue="default",
+        priority_weight=1,
+    )
+    return ExecuteTask(
+        ti=ti,
+        dag_rel_path=Path("test_dag.py"),
+        token="test_token",
+        bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+        log_path="test.log",
+    )
+
+
+def _make_execute_callback() -> ExecuteCallback:  # Airflow 3.3+ only: CallbackDTO import is gated
+    callback_data = CallbackDTO(
+        id=str(uuid4()),  # fresh id per call
+        fetch_method=CallbackFetchMethod.IMPORT_PATH,
+        data={
+            "path": "builtins.dict",
+            "kwargs": {"a": 1, "b": 2, "c": 3},
+        },
+    )
+    return ExecuteCallback(
+        callback=callback_data,
+        dag_rel_path=Path("test.py"),
+        bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+        token="test_token",
+        log_path="test.log",
+    )
+
+
+@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="ExecuteCallback workloads require Airflow 3.3+")
+class TestIsCallbackExecute:  # is_callback_execute: ExecuteTask -> False, ExecuteCallback -> True
+    def test_returns_false_for_execute_task(self):
+        workload = _make_execute_task()
+        assert is_callback_execute(workload) is False
+
+    def test_returns_true_for_execute_callback(self):
+        workload = _make_execute_callback()
+        assert is_callback_execute(workload) is True
+
+
+@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="ExecuteTypeBody union requires Airflow 3.3+")
+class TestExecuteTypeBody:  # ExecuteTypeBody must validate and round-trip both workload kinds
+    def setup_method(self):
+        self.adapter: TypeAdapter = TypeAdapter(ExecuteTypeBody)
+
+    def test_validate_execute_task_json(self):
+        workload = _make_execute_task()
+        json_str = workload.model_dump_json()
+
+        result = self.adapter.validate_json(json_str)
+
+        assert isinstance(result, ExecuteTask)
+        assert result.ti.dag_id == "test_dag"
+
+    def test_validate_execute_callback_json(self):
+        workload = _make_execute_callback()
+        json_str = workload.model_dump_json()
+
+        result = self.adapter.validate_json(json_str)
+
+        assert isinstance(result, ExecuteCallback)
+        assert result.callback.fetch_method == CallbackFetchMethod.IMPORT_PATH
+
+    def test_validate_execute_task_dict(self):
+        workload = _make_execute_task()
+        data = workload.model_dump()
+
+        result = self.adapter.validate_python(data)
+
+        assert isinstance(result, ExecuteTask)
+
+    def test_validate_execute_callback_dict(self):
+        workload = _make_execute_callback()
+        data = workload.model_dump()
+
+        result = self.adapter.validate_python(data)
+
+        assert isinstance(result, ExecuteCallback)
+
+    def test_roundtrip_execute_task(self):
+        original = _make_execute_task()
+        json_str = self.adapter.dump_json(original)  # bytes; validate_json accepts bytes as well
+        restored = self.adapter.validate_json(json_str)
+
+        assert isinstance(restored, ExecuteTask)
+        assert restored.ti.task_id == original.ti.task_id
+        assert restored.ti.dag_id == original.ti.dag_id
+
+    def test_roundtrip_execute_callback(self):
+        original = _make_execute_callback()
+        json_str = self.adapter.dump_json(original)  # bytes; validate_json accepts bytes as well
+        restored = self.adapter.validate_json(json_str)
+
+        assert isinstance(restored, ExecuteCallback)
+        assert restored.callback.id == original.callback.id
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
index 22c63d12cea..c120e86f088 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
@@ -17,22 +17,34 @@
 from __future__ import annotations
 
 import json
+from pathlib import Path
 from typing import TYPE_CHECKING
 from unittest.mock import patch
+from uuid import uuid4
 
 import pytest
 from sqlalchemy import delete, select
 
+from airflow.executors.workloads import BundleInfo, ExecuteTask
 from airflow.providers.common.compat.sdk import Stats
 from airflow.providers.edge3.models.edge_job import EdgeJobModel
+from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG
 from airflow.providers.edge3.worker_api.datamodels import WorkerQueuesBody
-from airflow.providers.edge3.worker_api.routes.jobs import fetch, state
+from airflow.providers.edge3.worker_api.routes.jobs import fetch, 
parse_command, state
 from airflow.utils.session import create_session
 from airflow.utils.state import TaskInstanceState
 
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
+    from airflow.executors.workloads import ExecuteCallback
+
+if AIRFLOW_V_3_3_PLUS:
+    from airflow.executors.workloads import CallbackFetchMethod, 
ExecuteCallback, TaskInstanceDTO
+    from airflow.executors.workloads.callback import CallbackDTO
+
 pytestmark = pytest.mark.db_test
 
 DAG_ID = "my_dag"
@@ -198,3 +210,76 @@ class TestJobsApiRoutes:
             assert result3 is None
             fetched_dag_ids = {result1.dag_id, result2.dag_id}
             assert fetched_dag_ids == {"dag_a", "dag_b"}
+
+
+@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="ExecuteCallback workloads require Airflow 3.3+")
+class TestParseCommand:  # parse_command yields ExecuteTask or ExecuteCallback from dag_id/run_id
+    def _make_execute_task(self) -> ExecuteTask:
+        ti = TaskInstanceDTO(
+            id=uuid4(),
+            dag_version_id=uuid4(),
+            task_id="test_task",
+            dag_id="test_dag",
+            run_id="test_run",
+            try_number=1,
+            map_index=-1,
+            pool_slots=1,
+            queue="default",
+            priority_weight=1,
+        )
+        return ExecuteTask(
+            ti=ti,
+            dag_rel_path=Path("test_dag.py"),
+            token="test_token",
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            log_path="test.log",
+        )
+
+    def _make_execute_callback(self) -> ExecuteCallback:
+        callback_data = CallbackDTO(
+            id="12345678-1234-5678-1234-567812345678",
+            fetch_method=CallbackFetchMethod.IMPORT_PATH,
+            data={
+                "path": "builtins.dict",
+                "kwargs": {"a": 1, "b": 2, "c": 3},
+            },
+        )
+        return ExecuteCallback(
+            callback=callback_data,
+            dag_rel_path=Path("test.py"),
+            bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+            token="test_token",
+            log_path="test.log",
+        )
+
+    def test_parse_command_execute_task(self):
+        workload = self._make_execute_task()
+        command_json = workload.model_dump_json()
+
+        result = parse_command(command_json, dag_id="test_dag", run_id="test_run")
+
+        assert isinstance(result, ExecuteTask)
+        assert result.ti.dag_id == "test_dag"
+        assert result.ti.task_id == "test_task"
+
+    def test_parse_command_execute_callback(self):
+        workload = self._make_execute_callback()
+        command_json = workload.model_dump_json()
+
+        dag_id = EXECUTE_CALLBACK_TAG
+        run_id = f"{EXECUTE_CALLBACK_TAG}-{workload.callback.key}"
+
+        result = parse_command(command_json, dag_id=dag_id, run_id=run_id)
+
+        assert isinstance(result, ExecuteCallback)
+        assert result.callback.id == "12345678-1234-5678-1234-567812345678"
+        assert result.callback.fetch_method == CallbackFetchMethod.IMPORT_PATH
+
+    def test_parse_command_non_callback_dag_id_returns_execute_task(self):
+        """A run_id prefixed with EXECUTE_CALLBACK_TAG is not enough: dag_id must equal it too."""
+        workload = self._make_execute_task()
+        command_json = workload.model_dump_json()
+
+        result = parse_command(command_json, dag_id="some_dag", run_id=f"{EXECUTE_CALLBACK_TAG}-something")
+
+        assert isinstance(result, ExecuteTask)


Reply via email to