This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 161beebc771 Migrate Edge calls for Worker to FastAPI part 3 - Jobs 
routes (#44433)
161beebc771 is described below

commit 161beebc771329ad0525f4df39b46c6f72776034
Author: Jens Scheffler <[email protected]>
AuthorDate: Sun Dec 1 12:01:29 2024 +0100

    Migrate Edge calls for Worker to FastAPI part 3 - Jobs routes (#44433)
    
    * Migrate Edge calls for Worker to FastAPI 3 - Jobs route
    
    * Review Feedback
    
    * Remove outdated type hints from review feedback
    
    * Update providers/src/airflow/providers/edge/worker_api/routes/jobs.py
    
    Co-authored-by: Copilot <[email protected]>
    
    * Add missing filter for free concurrency
    
    ---------
    
    Co-authored-by: Copilot <[email protected]>
---
 providers/src/airflow/providers/edge/CHANGELOG.rst |   8 +
 providers/src/airflow/providers/edge/__init__.py   |   2 +-
 .../src/airflow/providers/edge/cli/api_client.py   |  34 ++-
 .../src/airflow/providers/edge/cli/edge_command.py |  21 +-
 .../providers/edge/openapi/edge_worker_api_v1.yaml | 240 +++++++++++++++++++++
 providers/src/airflow/providers/edge/provider.yaml |   2 +-
 .../src/airflow/providers/edge/worker_api/app.py   |   2 +
 .../providers/edge/worker_api/datamodels.py        |  72 ++++++-
 .../providers/edge/worker_api/routes/_v2_routes.py |  48 ++++-
 .../providers/edge/worker_api/routes/jobs.py       | 130 +++++++++++
 providers/tests/edge/cli/test_edge_command.py      |  24 +--
 11 files changed, 541 insertions(+), 42 deletions(-)

diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst 
b/providers/src/airflow/providers/edge/CHANGELOG.rst
index 93900dfeb50..05c3d728246 100644
--- a/providers/src/airflow/providers/edge/CHANGELOG.rst
+++ b/providers/src/airflow/providers/edge/CHANGELOG.rst
@@ -27,6 +27,14 @@
 Changelog
 ---------
 
+0.8.2pre0
+.........
+
+Misc
+~~~~
+
+* ``Migrate worker job calls to FastAPI.``
+
 0.8.1pre0
 .........
 
diff --git a/providers/src/airflow/providers/edge/__init__.py 
b/providers/src/airflow/providers/edge/__init__.py
index 8c53d0bed1b..d826c633ead 100644
--- a/providers/src/airflow/providers/edge/__init__.py
+++ b/providers/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "0.8.1pre0"
+__version__ = "0.8.2pre0"
 
 if 
packaging.version.parse(packaging.version.parse(airflow_version).base_version) 
< packaging.version.parse(
     "2.10.0"
diff --git a/providers/src/airflow/providers/edge/cli/api_client.py 
b/providers/src/airflow/providers/edge/cli/api_client.py
index 9b5781e359d..c0a0144f5fe 100644
--- a/providers/src/airflow/providers/edge/cli/api_client.py
+++ b/providers/src/airflow/providers/edge/cli/api_client.py
@@ -32,7 +32,13 @@ from urllib3.exceptions import NewConnectionError
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.edge.worker_api.auth import jwt_signer
-from airflow.providers.edge.worker_api.datamodels import PushLogsBody, 
WorkerStateBody
+from airflow.providers.edge.worker_api.datamodels import (
+    EdgeJobFetched,
+    PushLogsBody,
+    WorkerQueuesBody,
+    WorkerStateBody,
+)
+from airflow.utils.state import TaskInstanceState  # noqa: TC001
 
 if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
@@ -114,6 +120,28 @@ def worker_set_state(
     )
 
 
+def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) 
-> EdgeJobFetched | None:
+    """Fetch a job to execute on the edge worker."""
+    result = _make_generic_request(
+        "GET",
+        f"jobs/fetch/{quote(hostname)}",
+        WorkerQueuesBody(queues=queues, 
free_concurrency=free_concurrency).model_dump_json(
+            exclude_unset=True
+        ),
+    )
+    if result:
+        return EdgeJobFetched(**result)
+    return None
+
+
+def jobs_set_state(key: TaskInstanceKey, state: TaskInstanceState) -> None:
+    """Set the state of a job."""
+    _make_generic_request(
+        "PATCH",
+        
f"jobs/state/{key.dag_id}/{key.task_id}/{key.run_id}/{key.try_number}/{key.map_index}/{state}",
+    )
+
+
 def logs_logfile_path(task: TaskInstanceKey) -> Path:
     """Elaborate the path and filename to expect from task execution."""
     result = _make_generic_request(
@@ -133,5 +161,7 @@ def logs_push(
     _make_generic_request(
         "POST",
         
f"logs/push/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
-        PushLogsBody(log_chunk_time=log_chunk_time, 
log_chunk_data=log_chunk_data).model_dump_json(),
+        PushLogsBody(log_chunk_time=log_chunk_time, 
log_chunk_data=log_chunk_data).model_dump_json(
+            exclude_unset=True
+        ),
     )
diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py 
b/providers/src/airflow/providers/edge/cli/edge_command.py
index f175d6e77b3..d93a2269973 100644
--- a/providers/src/airflow/providers/edge/cli/edge_command.py
+++ b/providers/src/airflow/providers/edge/cli/edge_command.py
@@ -26,6 +26,7 @@ from datetime import datetime
 from pathlib import Path
 from subprocess import Popen
 from time import sleep
+from typing import TYPE_CHECKING
 
 import psutil
 from lockfile.pidlockfile import read_pid_from_pidfile, 
remove_existing_pidfile, write_pid_to_pidfile
@@ -37,18 +38,22 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.edge import __version__ as edge_provider_version
 from airflow.providers.edge.cli.api_client import (
+    jobs_fetch,
+    jobs_set_state,
     logs_logfile_path,
     logs_push,
     worker_register,
     worker_set_state,
 )
-from airflow.providers.edge.models.edge_job import EdgeJob
 from airflow.providers.edge.models.edge_worker import EdgeWorkerState, 
EdgeWorkerVersionException
 from airflow.utils import cli as cli_utils, timezone
 from airflow.utils.platform import IS_WINDOWS
 from airflow.utils.providers_configuration_loader import 
providers_configuration_loaded
 from airflow.utils.state import TaskInstanceState
 
+if TYPE_CHECKING:
+    from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched
+
 logger = logging.getLogger(__name__)
 EDGE_WORKER_PROCESS_NAME = "edge-worker"
 EDGE_WORKER_HEADER = "\n".join(
@@ -81,7 +86,7 @@ def force_use_internal_api_on_edge_worker():
         if AIRFLOW_V_3_0_PLUS:
             # Obvious TODO Make EdgeWorker compatible with Airflow 3 (again)
             raise SystemExit(
-                "Error: EdgeWorker is currently broken on AIrflow 3/main due 
to removal of AIP-44, rework for AIP-72."
+                "Error: EdgeWorker is currently broken on Airflow 3/main due 
to removal of AIP-44, rework for AIP-72."
             )
 
         api_url = conf.get("edge", "api_url")
@@ -141,7 +146,7 @@ def _write_pid_to_pidfile(pid_file_path: str):
 class _Job:
     """Holds all information for a task/job to be executed as bundle."""
 
-    edge_job: EdgeJob
+    edge_job: EdgeJobFetched
     process: Popen
     logfile: Path
     logsize: int
@@ -240,9 +245,7 @@ class _EdgeWorkerCli:
     def fetch_job(self) -> bool:
         """Fetch and start a new job from central site."""
         logger.debug("Attempting to fetch a new job...")
-        edge_job = EdgeJob.reserve_task(
-            worker_name=self.hostname, free_concurrency=self.free_concurrency, 
queues=self.queues
-        )
+        edge_job = jobs_fetch(self.hostname, self.queues, 
self.free_concurrency)
         if edge_job:
             logger.info("Received job: %s", edge_job)
             env = os.environ.copy()
@@ -252,7 +255,7 @@ class _EdgeWorkerCli:
             process = Popen(edge_job.command, close_fds=True, env=env, 
start_new_session=True)
             logfile = logs_logfile_path(edge_job.key)
             self.jobs.append(_Job(edge_job, process, logfile, 0))
-            EdgeJob.set_state(edge_job.key, TaskInstanceState.RUNNING)
+            jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
             return True
 
         logger.info("No new job to process%s", f", {len(self.jobs)} still 
running" if self.jobs else "")
@@ -268,10 +271,10 @@ class _EdgeWorkerCli:
                 self.jobs.remove(job)
                 if job.process.returncode == 0:
                     logger.info("Job completed: %s", job.edge_job)
-                    EdgeJob.set_state(job.edge_job.key, 
TaskInstanceState.SUCCESS)
+                    jobs_set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
                 else:
                     logger.error("Job failed: %s", job.edge_job)
-                    EdgeJob.set_state(job.edge_job.key, 
TaskInstanceState.FAILED)
+                    jobs_set_state(job.edge_job.key, TaskInstanceState.FAILED)
             else:
                 used_concurrency += job.edge_job.concurrency_slots
 
diff --git 
a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml 
b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
index 7915bdb5b4a..ef1f24a2288 100644
--- a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
+++ b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
@@ -178,6 +178,161 @@ paths:
       summary: Register
       tags:
       - Worker
+  /jobs/fetch/{worker_name}:
+    get:
+      description: Fetch a job to execute on the edge worker.
+      x-openapi-router-controller: 
airflow.providers.edge.worker_api.routes._v2_routes
+      operationId: job_fetch_v2
+      parameters:
+      - in: path
+        name: worker_name
+        required: true
+        schema:
+          title: Worker Name
+          type: string
+      - description: JWT Authorization Token
+        in: header
+        name: authorization
+        required: true
+        schema:
+          description: JWT Authorization Token
+          title: Authorization
+          type: string
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/WorkerQueuesBody'
+              description: The queues and free capacity from which the worker
+                can fetch jobs.
+              title: Worker queues
+        required: true
+      responses:
+        '200':
+          content:
+            application/json:
+              schema:
+                anyOf:
+                - $ref: '#/components/schemas/EdgeJobFetched'
+                - type: object
+                  nullable: true
+                title: Response Fetch
+          description: Successful Response
+        '400':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Bad Request
+        '403':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Forbidden
+        '422':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
+          description: Validation Error
+      summary: Fetch
+      tags:
+      - Jobs
+  /jobs/state/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}/{state}:
+    patch:
+      description: Update the state of a job running on the edge worker.
+      x-openapi-router-controller: 
airflow.providers.edge.worker_api.routes._v2_routes
+      operationId: job_state_v2
+      parameters:
+      - description: Identifier of the DAG to which the task belongs.
+        in: path
+        name: dag_id
+        required: true
+        schema:
+          description: Identifier of the DAG to which the task belongs.
+          title: Dag ID
+          type: string
+      - description: Task name in the DAG.
+        in: path
+        name: task_id
+        required: true
+        schema:
+          description: Task name in the DAG.
+          title: Task ID
+          type: string
+      - description: Run ID of the DAG execution.
+        in: path
+        name: run_id
+        required: true
+        schema:
+          description: Run ID of the DAG execution.
+          title: Run ID
+          type: string
+      - description: The number of attempts to execute this task.
+        in: path
+        name: try_number
+        required: true
+        schema:
+          description: The number of attempts to execute this task.
+          title: Try Number
+          type: integer
+      - description: For dynamically mapped tasks the mapping number, -1 if 
the task
+          is not mapped.
+        in: path
+        name: map_index
+        required: true
+        schema:
+          description: For dynamically mapped tasks the mapping number, -1 if 
the
+            task is not mapped.
+          title: Map Index
+          type: string  # This should be integer, but Connexion/Flask do not 
support negative integers in path parameters
+      - description: State of the assigned task under execution.
+        in: path
+        name: state
+        required: true
+        schema:
+          $ref: '#/components/schemas/TaskInstanceState'
+          description: State of the assigned task under execution.
+          title: Task State
+      - description: JWT Authorization Token
+        in: header
+        name: authorization
+        required: true
+        schema:
+          description: JWT Authorization Token
+          title: Authorization
+          type: string
+      responses:
+        '200':
+          content:
+            application/json:
+              schema:
+                title: Response State
+                type: object
+                nullable: true
+          description: Successful Response
+        '400':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Bad Request
+        '403':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Forbidden
+        '422':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
+          description: Validation Error
+      summary: State
+      tags:
+      - Jobs
   /logs/logfile_path/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}:
     get:
       description: Elaborate the path and filename to expect from task 
execution.
@@ -464,6 +619,91 @@ components:
           title: Sysinfo
           type: object
       title: WorkerStateBody
+    WorkerQueuesBody:
+      description: Queues that a worker supports to run jobs on.
+      properties:
+        queues:
+          anyOf:
+          - items:
+              type: string
+            type: array
+          - type: object
+            nullable: true
+          description: List of queues the worker is pulling jobs from. If not 
provided,
+            worker pulls from all queues.
+          title: Queues
+        free_concurrency:
+          description: Number of free slots for running tasks.
+          title: Free Concurrency
+          type: integer
+      required:
+      - queues
+      - free_concurrency
+      title: WorkerQueuesBody
+      type: object
+    EdgeJobFetched:
+      description: Job that is to be executed on the edge worker.
+      properties:
+        command:
+          description: Command line to use to execute the job.
+          items:
+            type: string
+          title: Command
+          type: array
+        concurrency_slots:
+          description: Number of slots to use for the task.
+          title: Concurrency Slots
+          type: integer
+        dag_id:
+          description: Identifier of the DAG to which the task belongs.
+          title: Dag ID
+          type: string
+        map_index:
+          description: For dynamically mapped tasks the mapping number, -1 if 
the
+            task is not mapped.
+          title: Map Index
+          type: integer
+        run_id:
+          description: Run ID of the DAG execution.
+          title: Run ID
+          type: string
+        task_id:
+          description: Task name in the DAG.
+          title: Task ID
+          type: string
+        try_number:
+          description: The number of attempts to execute this task.
+          title: Try Number
+          type: integer
+      required:
+      - dag_id
+      - task_id
+      - run_id
+      - map_index
+      - try_number
+      - command
+      title: EdgeJobFetched
+      type: object
+    TaskInstanceState:
+      description: 'All possible states that a Task Instance can be in.
+
+
+        Note that None is also allowed, so always use this in a type hint with 
Optional.'
+      enum:
+      - removed
+      - scheduled
+      - queued
+      - running
+      - success
+      - restarting
+      - failed
+      - up_for_retry
+      - up_for_reschedule
+      - upstream_failed
+      - skipped
+      - deferred
+      title: TaskInstanceState
+      type: string
     PushLogsBody:
       description: Incremental new log content from worker.
       properties:
diff --git a/providers/src/airflow/providers/edge/provider.yaml 
b/providers/src/airflow/providers/edge/provider.yaml
index c4c289b228a..229f1ad68e4 100644
--- a/providers/src/airflow/providers/edge/provider.yaml
+++ b/providers/src/airflow/providers/edge/provider.yaml
@@ -27,7 +27,7 @@ source-date-epoch: 1729683247
 
 # note that those versions are maintained by release manager - do not update 
them manually
 versions:
-  - 0.8.1pre0
+  - 0.8.2pre0
 
 dependencies:
   - apache-airflow>=2.10.0
diff --git a/providers/src/airflow/providers/edge/worker_api/app.py 
b/providers/src/airflow/providers/edge/worker_api/app.py
index e90c5c47096..67325c8bdf3 100644
--- a/providers/src/airflow/providers/edge/worker_api/app.py
+++ b/providers/src/airflow/providers/edge/worker_api/app.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 from fastapi import FastAPI
 
 from airflow.providers.edge.worker_api.routes.health import health_router
+from airflow.providers.edge.worker_api.routes.jobs import jobs_router
 from airflow.providers.edge.worker_api.routes.logs import logs_router
 from airflow.providers.edge.worker_api.routes.worker import worker_router
 
@@ -36,6 +37,7 @@ def create_edge_worker_api_app() -> FastAPI:
     )
 
     edge_worker_api_app.include_router(health_router)
+    edge_worker_api_app.include_router(jobs_router)
     edge_worker_api_app.include_router(logs_router)
     edge_worker_api_app.include_router(worker_router)
     return edge_worker_api_app
diff --git a/providers/src/airflow/providers/edge/worker_api/datamodels.py 
b/providers/src/airflow/providers/edge/worker_api/datamodels.py
index f4455c7e1e2..fcfc47fc5bd 100644
--- a/providers/src/airflow/providers/edge/worker_api/datamodels.py
+++ b/providers/src/airflow/providers/edge/worker_api/datamodels.py
@@ -17,17 +17,14 @@
 from __future__ import annotations
 
 from datetime import datetime
-from typing import (  # noqa: UP035 - prevent pytest failing in back-compat
+from typing import (
     Annotated,
     Any,
-    Dict,
-    List,
-    Optional,
-    Union,
 )
 
 from pydantic import BaseModel, Field
 
+from airflow.models.taskinstancekey import TaskInstanceKey
 from airflow.providers.edge.models.edge_worker import EdgeWorkerState  # noqa: 
TCH001
 from airflow.providers.edge.worker_api.routes._v2_compat import Path
 
@@ -43,6 +40,7 @@ class WorkerApiDocs:
         title="Map Index",
         description="For dynamically mapped tasks the mapping number, -1 if 
the task is not mapped.",
     )
+    state = Path(title="Task State", description="State of the assigned task 
under execution.")
 
 
 class JsonRpcRequestBase(BaseModel):
@@ -59,29 +57,81 @@ class JsonRpcRequest(JsonRpcRequestBase):
 
     jsonrpc: Annotated[str, Field(description="JSON RPC Version", 
examples=["2.0"])]
     params: Annotated[
-        Optional[Dict[str, Any]],  # noqa: UP006, UP007 - prevent pytest 
failing in back-compat
+        dict[str, Any] | None,
         Field(description="Dictionary of parameters passed to the method."),
     ]
 
 
-class WorkerStateBody(BaseModel):
+class EdgeJobBase(BaseModel):
+    """Basic attributes of a job on the edge worker."""
+
+    dag_id: Annotated[
+        str, Field(title="Dag ID", description="Identifier of the DAG to which 
the task belongs.")
+    ]
+    task_id: Annotated[str, Field(title="Task ID", description="Task name in 
the DAG.")]
+    run_id: Annotated[str, Field(title="Run ID", description="Run ID of the 
DAG execution.")]
+    map_index: Annotated[
+        int,
+        Field(
+            title="Map Index",
+            description="For dynamically mapped tasks the mapping number, -1 
if the task is not mapped.",
+        ),
+    ]
+    try_number: Annotated[
+        int, Field(title="Try Number", description="The number of attempts to execute this task.")
+    ]
+
+    @property
+    def key(self) -> TaskInstanceKey:
+        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, 
self.try_number, self.map_index)
+
+
+class EdgeJobFetched(EdgeJobBase):
+    """Job that is to be executed on the edge worker."""
+
+    command: Annotated[
+        list[str], Field(title="Command", description="Command line to use to 
execute the job.")
+    ]
+    concurrency_slots: Annotated[int, Field(description="Number of concurrency 
slots the job requires.")]
+
+
+class WorkerQueuesBase(BaseModel):
+    """Queues that a worker supports to run jobs on."""
+
+    queues: Annotated[
+        list[str] | None,
+        Field(
+            None,
+            description="List of queues the worker is pulling jobs from. If 
not provided, worker pulls from all queues.",
+        ),
+    ]
+
+
+class WorkerQueuesBody(WorkerQueuesBase):
+    """Queues that a worker supports to run jobs on."""
+
+    free_concurrency: Annotated[int, Field(description="Number of free 
concurrency slots on the worker.")]
+
+
+class WorkerStateBody(WorkerQueuesBase):
     """Details of the worker state sent to the scheduler."""
 
     state: Annotated[EdgeWorkerState, Field(description="State of the worker 
from the view of the worker.")]
     jobs_active: Annotated[int, Field(description="Number of active jobs the 
worker is running.")] = 0
     queues: Annotated[
-        Optional[List[str]],  # noqa: UP006, UP007 - prevent pytest failing in 
back-compat
+        list[str] | None,
         Field(
             description="List of queues the worker is pulling jobs from. If 
not provided, worker pulls from all queues."
         ),
     ] = None
     sysinfo: Annotated[
-        Dict[str, Union[str, int]],  # noqa: UP006, UP007 - prevent pytest 
failing in back-compat
+        dict[str, str | int],
         Field(
             description="System information of the worker.",
             examples=[
                 {
                     "concurrency": 4,
+                    "free_concurrency": 3,
                     "airflow_version": "2.0.0",
                     "edge_provider_version": "1.0.0",
                 }
@@ -94,11 +144,11 @@ class WorkerQueueUpdateBody(BaseModel):
     """Changed queues for the worker."""
 
     new_queues: Annotated[
-        Optional[List[str]],  # noqa: UP006, UP007 - prevent pytest failing in 
back-compat
+        list[str] | None,
         Field(description="Additional queues to be added to worker."),
     ]
     remove_queues: Annotated[
-        Optional[List[str]],  # noqa: UP006, UP007 - prevent pytest failing in 
back-compat
+        list[str] | None,
         Field(description="Queues to remove from worker."),
     ]
 
diff --git 
a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py 
b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
index 767aef14b34..b00f5cf41b9 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
@@ -28,8 +28,15 @@ from flask import Response, request
 
 from airflow.exceptions import AirflowException
 from airflow.providers.edge.worker_api.auth import jwt_token_authorization, 
jwt_token_authorization_rpc
-from airflow.providers.edge.worker_api.datamodels import JsonRpcRequest, 
PushLogsBody, WorkerStateBody
+from airflow.providers.edge.worker_api.datamodels import (
+    EdgeJobFetched,
+    JsonRpcRequest,
+    PushLogsBody,
+    WorkerQueuesBody,
+    WorkerStateBody,
+)
 from airflow.providers.edge.worker_api.routes._v2_compat import HTTPException, 
status
+from airflow.providers.edge.worker_api.routes.jobs import fetch, state as 
state_api
 from airflow.providers.edge.worker_api.routes.logs import logfile_path, 
push_logs
 from airflow.providers.edge.worker_api.routes.worker import register, set_state
 from airflow.serialization.serialized_objects import BaseSerialization
@@ -37,6 +44,7 @@ from airflow.utils.session import NEW_SESSION, 
create_session, provide_session
 
 if TYPE_CHECKING:
     from airflow.api_connexion.types import APIResponse
+    from airflow.utils.state import TaskInstanceState
 
 
 log = logging.getLogger(__name__)
@@ -267,6 +275,44 @@ def set_state_v2(worker_name: str, body: dict[str, Any], 
session=NEW_SESSION) ->
         return e.to_response()  # type: ignore[attr-defined]
 
 
+@provide_session
+def job_fetch_v2(worker_name: str, body: dict[str, Any] | None = None, 
session=NEW_SESSION) -> Any:
+    """Handle Edge Worker API `/edge_worker/v1/jobs/fetch/{worker_name}` 
endpoint for Airflow 2.10."""
+    from flask import request
+
+    try:
+        auth = request.headers.get("Authorization", "")
+        jwt_token_authorization(request.path, auth)
+        queues = body["queues"] if body else None
+        free_concurrency = body["free_concurrency"] if body else 1
+        request_obj = WorkerQueuesBody(queues=queues, 
free_concurrency=free_concurrency)
+        job: EdgeJobFetched | None = fetch(worker_name, request_obj, session)
+        return job.model_dump() if job is not None else None
+    except HTTPException as e:
+        return e.to_response()  # type: ignore[attr-defined]
+
+
+@provide_session
+def job_state_v2(
+    dag_id: str,
+    task_id: str,
+    run_id: str,
+    try_number: int,
+    map_index: str,  # Note: Connexion can not have negative numbers in path 
parameters, use string therefore
+    state: TaskInstanceState,
+    session=NEW_SESSION,
+) -> Any:
+    """Handle Edge Worker API 
`/jobs/state/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}/{state}` 
endpoint for Airflow 2.10."""
+    from flask import request
+
+    try:
+        auth = request.headers.get("Authorization", "")
+        jwt_token_authorization(request.path, auth)
+        state_api(dag_id, task_id, run_id, try_number, int(map_index), state, 
session)
+    except HTTPException as e:
+        return e.to_response()  # type: ignore[attr-defined]
+
+
 def logfile_path_v2(
     dag_id: str,
     task_id: str,
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/jobs.py 
b/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
new file mode 100644
index 00000000000..289fc3eed99
--- /dev/null
+++ b/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from ast import literal_eval
+from typing import Annotated
+
+from sqlalchemy import select, update
+
+from airflow.providers.edge.models.edge_job import EdgeJobModel
+from airflow.providers.edge.worker_api.auth import jwt_token_authorization_rest
+from airflow.providers.edge.worker_api.datamodels import (
+    EdgeJobFetched,
+    WorkerApiDocs,
+    WorkerQueuesBody,
+)
+from airflow.providers.edge.worker_api.routes._v2_compat import (
+    AirflowRouter,
+    Body,
+    Depends,
+    SessionDep,
+    create_openapi_http_exception_doc,
+    status,
+)
+from airflow.utils import timezone
+from airflow.utils.sqlalchemy import with_row_locks
+from airflow.utils.state import TaskInstanceState
+
+jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
+
+
+@jobs_router.get(
+    "/fetch/{worker_name}",
+    dependencies=[Depends(jwt_token_authorization_rest)],
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_403_FORBIDDEN,
+        ]
+    ),
+)
+def fetch(
+    worker_name: str,
+    body: Annotated[
+        WorkerQueuesBody,
+        Body(
+            title="Log data chunks",
+            description="The queues and capacity from which the worker can 
fetch jobs.",
+        ),
+    ],
+    session: SessionDep,
+) -> EdgeJobFetched | None:
+    """Fetch a job to execute on the edge worker."""
+    query = (
+        select(EdgeJobModel)
+        .where(
+            EdgeJobModel.state == TaskInstanceState.QUEUED,
+            EdgeJobModel.concurrency_slots <= body.free_concurrency,
+        )
+        .order_by(EdgeJobModel.queued_dttm)
+    )
+    if body.queues:
+        query = query.where(EdgeJobModel.queue.in_(body.queues))
+    query = query.limit(1)
+    query = with_row_locks(query, of=EdgeJobModel, session=session, 
skip_locked=True)
+    job: EdgeJobModel = session.scalar(query)
+    if not job:
+        return None
+    job.state = TaskInstanceState.RUNNING
+    job.edge_worker = worker_name
+    job.last_update = timezone.utcnow()
+    session.commit()
+    return EdgeJobFetched(
+        dag_id=job.dag_id,
+        task_id=job.task_id,
+        run_id=job.run_id,
+        map_index=job.map_index,
+        try_number=job.try_number,
+        command=literal_eval(job.command),
+        concurrency_slots=job.concurrency_slots,
+    )
+
+
+@jobs_router.patch(
+    "/state/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}/{state}",
+    dependencies=[Depends(jwt_token_authorization_rest)],
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_403_FORBIDDEN,
+        ]
+    ),
+)
+def state(
+    dag_id: Annotated[str, WorkerApiDocs.dag_id],
+    task_id: Annotated[str, WorkerApiDocs.task_id],
+    run_id: Annotated[str, WorkerApiDocs.run_id],
+    try_number: Annotated[int, WorkerApiDocs.try_number],
+    map_index: Annotated[int, WorkerApiDocs.map_index],
+    state: Annotated[TaskInstanceState, WorkerApiDocs.state],
+    session: SessionDep,
+) -> None:
+    """Update the state of a job running on the edge worker."""
+    query = (
+        update(EdgeJobModel)
+        .where(
+            EdgeJobModel.dag_id == dag_id,
+            EdgeJobModel.task_id == task_id,
+            EdgeJobModel.run_id == run_id,
+            EdgeJobModel.map_index == map_index,
+            EdgeJobModel.try_number == try_number,
+        )
+        .values(state=state, last_update=timezone.utcnow())
+    )
+    session.execute(query)
diff --git a/providers/tests/edge/cli/test_edge_command.py 
b/providers/tests/edge/cli/test_edge_command.py
index 4f2706ef531..123b06af3f9 100644
--- a/providers/tests/edge/cli/test_edge_command.py
+++ b/providers/tests/edge/cli/test_edge_command.py
@@ -28,8 +28,8 @@ import time_machine
 
 from airflow.exceptions import AirflowException
 from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _Job, 
_write_pid_to_pidfile
-from airflow.providers.edge.models.edge_job import EdgeJob
 from airflow.providers.edge.models.edge_worker import EdgeWorkerState, 
EdgeWorkerVersionException
+from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched
 from airflow.utils import timezone
 from airflow.utils.state import TaskInstanceState
 
@@ -95,19 +95,14 @@ class TestEdgeWorkerCli:
 
         return [
             _Job(
-                edge_job=EdgeJob(
+                edge_job=EdgeJobFetched(
                     dag_id="test",
                     task_id="test1",
                     run_id="test",
                     map_index=-1,
                     try_number=1,
-                    state=TaskInstanceState.RUNNING,
-                    queue="test",
                     concurrency_slots=1,
                     command=["test", "command"],
-                    queued_dttm=datetime.now(),
-                    edge_worker=None,
-                    last_update=None,
                 ),
                 process=_MockPopen(),
                 logfile=logfile,
@@ -126,19 +121,14 @@ class TestEdgeWorkerCli:
         [
             pytest.param(None, False, (0, 0), id="no_job"),
             pytest.param(
-                EdgeJob(
+                EdgeJobFetched(
                     dag_id="test",
                     task_id="test",
                     run_id="test",
                     map_index=-1,
                     try_number=1,
-                    state=TaskInstanceState.QUEUED,
-                    queue="test",
                     concurrency_slots=1,
                     command=["test", "command"],
-                    queued_dttm=datetime.now(),
-                    edge_worker=None,
-                    last_update=None,
                 ),
                 True,
                 (1, 1),
@@ -146,9 +136,9 @@ class TestEdgeWorkerCli:
             ),
         ],
     )
-    @patch("airflow.providers.edge.models.edge_job.EdgeJob.reserve_task")
+    @patch("airflow.providers.edge.cli.edge_command.jobs_fetch")
     @patch("airflow.providers.edge.cli.edge_command.logs_logfile_path")
-    @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+    @patch("airflow.providers.edge.cli.edge_command.jobs_set_state")
     @patch("subprocess.Popen")
     def test_fetch_job(
         self,
@@ -181,7 +171,7 @@ class TestEdgeWorkerCli:
             == worker_with_job.concurrency - 
worker_with_job.jobs[0].edge_job.concurrency_slots
         )
 
-    @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+    @patch("airflow.providers.edge.cli.edge_command.jobs_set_state")
     def test_check_running_jobs_success(self, mock_set_state, worker_with_job: 
_EdgeWorkerCli):
         job = worker_with_job.jobs[0]
         job.process.generated_returncode = 0  # type: ignore[attr-defined]
@@ -191,7 +181,7 @@ class TestEdgeWorkerCli:
         mock_set_state.assert_called_once_with(job.edge_job.key, 
TaskInstanceState.SUCCESS)
         assert worker_with_job.free_concurrency == worker_with_job.concurrency
 
-    @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+    @patch("airflow.providers.edge.cli.edge_command.jobs_set_state")
     def test_check_running_jobs_failed(self, mock_set_state, worker_with_job: 
_EdgeWorkerCli):
         job = worker_with_job.jobs[0]
         job.process.generated_returncode = 42  # type: ignore[attr-defined]


Reply via email to