This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 161beebc771 Migrate Edge calls for Worker to FastAPI part 3 - Jobs routes (#44433)
161beebc771 is described below
commit 161beebc771329ad0525f4df39b46c6f72776034
Author: Jens Scheffler <[email protected]>
AuthorDate: Sun Dec 1 12:01:29 2024 +0100
Migrate Edge calls for Worker to FastAPI part 3 - Jobs routes (#44433)
* Migrate Edge calls for Worker to FastAPI part 3 - Jobs route
* Review Feedback
* Remove outdated type hints from review feedback
* Update providers/src/airflow/providers/edge/worker_api/routes/jobs.py
Co-authored-by: Copilot <[email protected]>
* Add missing filter for free concurrency
---------
Co-authored-by: Copilot <[email protected]>
---
providers/src/airflow/providers/edge/CHANGELOG.rst | 8 +
providers/src/airflow/providers/edge/__init__.py | 2 +-
.../src/airflow/providers/edge/cli/api_client.py | 34 ++-
.../src/airflow/providers/edge/cli/edge_command.py | 21 +-
.../providers/edge/openapi/edge_worker_api_v1.yaml | 240 +++++++++++++++++++++
providers/src/airflow/providers/edge/provider.yaml | 2 +-
.../src/airflow/providers/edge/worker_api/app.py | 2 +
.../providers/edge/worker_api/datamodels.py | 72 ++++++-
.../providers/edge/worker_api/routes/_v2_routes.py | 48 ++++-
.../providers/edge/worker_api/routes/jobs.py | 130 +++++++++++
providers/tests/edge/cli/test_edge_command.py | 24 +--
11 files changed, 541 insertions(+), 42 deletions(-)
diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst b/providers/src/airflow/providers/edge/CHANGELOG.rst
index 93900dfeb50..05c3d728246 100644
--- a/providers/src/airflow/providers/edge/CHANGELOG.rst
+++ b/providers/src/airflow/providers/edge/CHANGELOG.rst
@@ -27,6 +27,14 @@
Changelog
---------
+0.8.2pre0
+.........
+
+Misc
+~~~~
+
+* ``Migrate worker job calls to FastAPI.``
+
0.8.1pre0
.........
diff --git a/providers/src/airflow/providers/edge/__init__.py b/providers/src/airflow/providers/edge/__init__.py
index 8c53d0bed1b..d826c633ead 100644
--- a/providers/src/airflow/providers/edge/__init__.py
+++ b/providers/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
__all__ = ["__version__"]
-__version__ = "0.8.1pre0"
+__version__ = "0.8.2pre0"
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.10.0"
diff --git a/providers/src/airflow/providers/edge/cli/api_client.py b/providers/src/airflow/providers/edge/cli/api_client.py
index 9b5781e359d..c0a0144f5fe 100644
--- a/providers/src/airflow/providers/edge/cli/api_client.py
+++ b/providers/src/airflow/providers/edge/cli/api_client.py
@@ -32,7 +32,13 @@ from urllib3.exceptions import NewConnectionError
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.edge.worker_api.auth import jwt_signer
-from airflow.providers.edge.worker_api.datamodels import PushLogsBody, WorkerStateBody
+from airflow.providers.edge.worker_api.datamodels import (
+ EdgeJobFetched,
+ PushLogsBody,
+ WorkerQueuesBody,
+ WorkerStateBody,
+)
+from airflow.utils.state import TaskInstanceState # noqa: TC001
if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -114,6 +120,28 @@ def worker_set_state(
)
+def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -> EdgeJobFetched | None:
+ """Fetch a job to execute on the edge worker."""
+ result = _make_generic_request(
+ "GET",
+ f"jobs/fetch/{quote(hostname)}",
+        WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency).model_dump_json(
+ exclude_unset=True
+ ),
+ )
+ if result:
+ return EdgeJobFetched(**result)
+ return None
+
+
+def jobs_set_state(key: TaskInstanceKey, state: TaskInstanceState) -> None:
+ """Set the state of a job."""
+ _make_generic_request(
+ "PATCH",
+        f"jobs/state/{key.dag_id}/{key.task_id}/{key.run_id}/{key.try_number}/{key.map_index}/{state}",
+ )
+
+
def logs_logfile_path(task: TaskInstanceKey) -> Path:
"""Elaborate the path and filename to expect from task execution."""
result = _make_generic_request(
@@ -133,5 +161,7 @@ def logs_push(
_make_generic_request(
"POST",
f"logs/push/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
-        PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data).model_dump_json(),
+        PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data).model_dump_json(
+ exclude_unset=True
+ ),
)
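
For illustration, a minimal usage sketch of the two new client helpers (not part of the commit; the hostname, queue name and concurrency value are made-up examples):

    from airflow.providers.edge.cli.api_client import jobs_fetch, jobs_set_state
    from airflow.utils.state import TaskInstanceState

    # Ask the central site for one queued job that fits into two free slots.
    job = jobs_fetch(hostname="edge-host-1", queues=["default"], free_concurrency=2)
    if job is not None:
        # ... execute job.command locally, then report the terminal state ...
        jobs_set_state(job.key, TaskInstanceState.SUCCESS)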
diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py
index f175d6e77b3..d93a2269973 100644
--- a/providers/src/airflow/providers/edge/cli/edge_command.py
+++ b/providers/src/airflow/providers/edge/cli/edge_command.py
@@ -26,6 +26,7 @@ from datetime import datetime
from pathlib import Path
from subprocess import Popen
from time import sleep
+from typing import TYPE_CHECKING
import psutil
from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile
@@ -37,18 +38,22 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.edge import __version__ as edge_provider_version
from airflow.providers.edge.cli.api_client import (
+ jobs_fetch,
+ jobs_set_state,
logs_logfile_path,
logs_push,
worker_register,
worker_set_state,
)
-from airflow.providers.edge.models.edge_job import EdgeJob
from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException
from airflow.utils import cli as cli_utils, timezone
from airflow.utils.platform import IS_WINDOWS
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.state import TaskInstanceState
+if TYPE_CHECKING:
+ from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched
+
logger = logging.getLogger(__name__)
EDGE_WORKER_PROCESS_NAME = "edge-worker"
EDGE_WORKER_HEADER = "\n".join(
@@ -81,7 +86,7 @@ def force_use_internal_api_on_edge_worker():
if AIRFLOW_V_3_0_PLUS:
# Obvious TODO Make EdgeWorker compatible with Airflow 3 (again)
raise SystemExit(
-            "Error: EdgeWorker is currently broken on AIrflow 3/main due to removal of AIP-44, rework for AIP-72."
+            "Error: EdgeWorker is currently broken on Airflow 3/main due to removal of AIP-44, rework for AIP-72."
)
api_url = conf.get("edge", "api_url")
@@ -141,7 +146,7 @@ def _write_pid_to_pidfile(pid_file_path: str):
class _Job:
"""Holds all information for a task/job to be executed as bundle."""
- edge_job: EdgeJob
+ edge_job: EdgeJobFetched
process: Popen
logfile: Path
logsize: int
@@ -240,9 +245,7 @@ class _EdgeWorkerCli:
def fetch_job(self) -> bool:
"""Fetch and start a new job from central site."""
logger.debug("Attempting to fetch a new job...")
- edge_job = EdgeJob.reserve_task(
-        worker_name=self.hostname, free_concurrency=self.free_concurrency, queues=self.queues
- )
+    edge_job = jobs_fetch(self.hostname, self.queues, self.free_concurrency)
if edge_job:
logger.info("Received job: %s", edge_job)
env = os.environ.copy()
@@ -252,7 +255,7 @@ class _EdgeWorkerCli:
process = Popen(edge_job.command, close_fds=True, env=env, start_new_session=True)
logfile = logs_logfile_path(edge_job.key)
self.jobs.append(_Job(edge_job, process, logfile, 0))
- EdgeJob.set_state(edge_job.key, TaskInstanceState.RUNNING)
+ jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
return True
logger.info("No new job to process%s", f", {len(self.jobs)} still running" if self.jobs else "")
@@ -268,10 +271,10 @@ class _EdgeWorkerCli:
self.jobs.remove(job)
if job.process.returncode == 0:
logger.info("Job completed: %s", job.edge_job)
-                    EdgeJob.set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
+ jobs_set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
else:
logger.error("Job failed: %s", job.edge_job)
-                    EdgeJob.set_state(job.edge_job.key, TaskInstanceState.FAILED)
+ jobs_set_state(job.edge_job.key, TaskInstanceState.FAILED)
else:
used_concurrency += job.edge_job.concurrency_slots
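
Condensed, the worker cycle after this change runs purely over HTTP; a rough sketch (error handling, env preparation and concurrency bookkeeping omitted; hostname, queues and free_concurrency are attributes of _EdgeWorkerCli in the real code):

    from subprocess import Popen

    from airflow.providers.edge.cli.api_client import jobs_fetch, jobs_set_state, logs_logfile_path
    from airflow.utils.state import TaskInstanceState

    edge_job = jobs_fetch(hostname, queues, free_concurrency)  # replaces EdgeJob.reserve_task
    if edge_job:
        process = Popen(edge_job.command, close_fds=True, start_new_session=True)
        logfile = logs_logfile_path(edge_job.key)  # log path is also resolved via the API
        jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
        process.wait()  # the real worker polls running jobs instead of blocking
        jobs_set_state(
            edge_job.key,
            TaskInstanceState.SUCCESS if process.returncode == 0 else TaskInstanceState.FAILED,
        )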
diff --git a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
index 7915bdb5b4a..ef1f24a2288 100644
--- a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
+++ b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
@@ -178,6 +178,161 @@ paths:
summary: Register
tags:
- Worker
+ /jobs/fetch/{worker_name}:
+ get:
+ description: Fetch a job to execute on the edge worker.
+      x-openapi-router-controller: airflow.providers.edge.worker_api.routes._v2_routes
+ operationId: job_fetch_v2
+ parameters:
+ - in: path
+ name: worker_name
+ required: true
+ schema:
+ title: Worker Name
+ type: string
+ - description: JWT Authorization Token
+ in: header
+ name: authorization
+ required: true
+ schema:
+ description: JWT Authorization Token
+ title: Authorization
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/WorkerQueuesBody'
+        description: The queues and capacity from which the worker can fetch jobs.
+        title: Queues and capacity
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ anyOf:
+ - $ref: '#/components/schemas/EdgeJobFetched'
+ - type: object
+ nullable: true
+ title: Response Fetch
+ description: Successful Response
+ '400':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Bad Request
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '422':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
+ description: Validation Error
+ summary: Fetch
+ tags:
+ - Jobs
+ /jobs/state/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}/{state}:
+ patch:
+ description: Update the state of a job running on the edge worker.
+      x-openapi-router-controller: airflow.providers.edge.worker_api.routes._v2_routes
+ operationId: job_state_v2
+ parameters:
+ - description: Identifier of the DAG to which the task belongs.
+ in: path
+ name: dag_id
+ required: true
+ schema:
+ description: Identifier of the DAG to which the task belongs.
+ title: Dag ID
+ type: string
+ - description: Task name in the DAG.
+ in: path
+ name: task_id
+ required: true
+ schema:
+ description: Task name in the DAG.
+ title: Task ID
+ type: string
+ - description: Run ID of the DAG execution.
+ in: path
+ name: run_id
+ required: true
+ schema:
+ description: Run ID of the DAG execution.
+ title: Run ID
+ type: string
+      - description: The number of the attempt to execute this task.
+ in: path
+ name: try_number
+ required: true
+ schema:
+          description: The number of the attempt to execute this task.
+ title: Try Number
+ type: integer
+      - description: For dynamically mapped tasks the mapping number, -1 if the task is not mapped.
+ in: path
+ name: map_index
+ required: true
+ schema:
+          description: For dynamically mapped tasks the mapping number, -1 if the task is not mapped.
+ title: Map Index
+          type: string # This should be integer, but Connexion/Flask do not support negative integers in path parameters
+ - description: State of the assigned task under execution.
+ in: path
+ name: state
+ required: true
+ schema:
+ $ref: '#/components/schemas/TaskInstanceState'
+ description: State of the assigned task under execution.
+ title: Task State
+ - description: JWT Authorization Token
+ in: header
+ name: authorization
+ required: true
+ schema:
+ description: JWT Authorization Token
+ title: Authorization
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ title: Response State
+ type: object
+ nullable: true
+ description: Successful Response
+ '400':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Bad Request
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '422':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
+ description: Validation Error
+ summary: State
+ tags:
+ - Jobs
/logs/logfile_path/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}:
get:
      description: Elaborate the path and filename to expect from task execution.
@@ -464,6 +619,91 @@ components:
title: Sysinfo
type: object
title: WorkerStateBody
+ WorkerQueuesBody:
+ description: Queues that a worker supports to run jobs on.
+ properties:
+ queues:
+ anyOf:
+ - items:
+ type: string
+ type: array
+ - type: object
+ nullable: true
+        description: List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues.
+ title: Queues
+ free_concurrency:
+ description: Number of free slots for running tasks.
+ title: Free Concurrency
+ type: integer
+ required:
+ - queues
+ - free_concurrency
+ title: WorkerQueuesBody
+ type: object
+ EdgeJobFetched:
+ description: Job that is to be executed on the edge worker.
+ properties:
+ command:
+ description: Command line to use to execute the job.
+ items:
+ type: string
+ title: Command
+ type: array
+ concurrency_slots:
+ description: Number of slots to use for the task.
+ title: Concurrency Slots
+ type: integer
+ dag_id:
+ description: Identifier of the DAG to which the task belongs.
+ title: Dag ID
+ type: string
+ map_index:
+        description: For dynamically mapped tasks the mapping number, -1 if the task is not mapped.
+ title: Map Index
+ type: integer
+ run_id:
+ description: Run ID of the DAG execution.
+ title: Run ID
+ type: string
+ task_id:
+ description: Task name in the DAG.
+ title: Task ID
+ type: string
+ try_number:
+          description: The number of the attempt to execute this task.
+ title: Try Number
+ type: integer
+ required:
+ - dag_id
+ - task_id
+ - run_id
+ - map_index
+ - try_number
+ - command
+ title: EdgeJobFetched
+ type: object
+ TaskInstanceState:
+ description: 'All possible states that a Task Instance can be in.
+
+
+        Note that None is also allowed, so always use this in a type hint with Optional.'
+ enum:
+ - removed
+ - scheduled
+ - queued
+ - running
+ - success
+ - restarting
+ - failed
+ - up_for_retry
+ - up_for_reschedule
+ - upstream_failed
+ - skipped
+ - deferred
+ title: TaskInstanceState
+ type: string
PushLogsBody:
description: Incremental new log content from worker.
properties:
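
Note that /jobs/fetch/{worker_name} is specified as a GET endpoint that carries a JSON request body (the WorkerQueuesBody). A sketch of a raw call, roughly mirroring what _make_generic_request in the api_client does; the base URL and token are placeholders:

    import requests

    resp = requests.request(
        "GET",  # a GET that carries a JSON body, as specified above
        "http://localhost:8080/edge_worker/v1/jobs/fetch/edge-host-1",
        data='{"queues": ["default"], "free_concurrency": 2}',
        headers={"Authorization": "<jwt token>", "Content-Type": "application/json"},
    )
    job = resp.json()  # an EdgeJobFetched payload, or null when nothing is queued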
diff --git a/providers/src/airflow/providers/edge/provider.yaml b/providers/src/airflow/providers/edge/provider.yaml
index c4c289b228a..229f1ad68e4 100644
--- a/providers/src/airflow/providers/edge/provider.yaml
+++ b/providers/src/airflow/providers/edge/provider.yaml
@@ -27,7 +27,7 @@ source-date-epoch: 1729683247
# note that those versions are maintained by release manager - do not update them manually
versions:
- - 0.8.1pre0
+ - 0.8.2pre0
dependencies:
- apache-airflow>=2.10.0
diff --git a/providers/src/airflow/providers/edge/worker_api/app.py b/providers/src/airflow/providers/edge/worker_api/app.py
index e90c5c47096..67325c8bdf3 100644
--- a/providers/src/airflow/providers/edge/worker_api/app.py
+++ b/providers/src/airflow/providers/edge/worker_api/app.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from fastapi import FastAPI
from airflow.providers.edge.worker_api.routes.health import health_router
+from airflow.providers.edge.worker_api.routes.jobs import jobs_router
from airflow.providers.edge.worker_api.routes.logs import logs_router
from airflow.providers.edge.worker_api.routes.worker import worker_router
@@ -36,6 +37,7 @@ def create_edge_worker_api_app() -> FastAPI:
)
edge_worker_api_app.include_router(health_router)
+ edge_worker_api_app.include_router(jobs_router)
edge_worker_api_app.include_router(logs_router)
edge_worker_api_app.include_router(worker_router)
return edge_worker_api_app
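
For local experiments the assembled sub-app can be served standalone; a minimal sketch assuming uvicorn is available (how the app is mounted in a real deployment is out of scope for this commit):

    import uvicorn

    from airflow.providers.edge.worker_api.app import create_edge_worker_api_app

    app = create_edge_worker_api_app()  # now includes jobs_router next to health, logs and worker
    uvicorn.run(app, host="127.0.0.1", port=8080)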
diff --git a/providers/src/airflow/providers/edge/worker_api/datamodels.py b/providers/src/airflow/providers/edge/worker_api/datamodels.py
index f4455c7e1e2..fcfc47fc5bd 100644
--- a/providers/src/airflow/providers/edge/worker_api/datamodels.py
+++ b/providers/src/airflow/providers/edge/worker_api/datamodels.py
@@ -17,17 +17,14 @@
from __future__ import annotations
from datetime import datetime
-from typing import ( # noqa: UP035 - prevent pytest failing in back-compat
+from typing import (
Annotated,
Any,
- Dict,
- List,
- Optional,
- Union,
)
from pydantic import BaseModel, Field
+from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.edge.models.edge_worker import EdgeWorkerState  # noqa: TCH001
from airflow.providers.edge.worker_api.routes._v2_compat import Path
@@ -43,6 +40,7 @@ class WorkerApiDocs:
title="Map Index",
description="For dynamically mapped tasks the mapping number, -1 if
the task is not mapped.",
)
+    state = Path(title="Task State", description="State of the assigned task under execution.")
class JsonRpcRequestBase(BaseModel):
@@ -59,29 +57,81 @@ class JsonRpcRequest(JsonRpcRequestBase):
    jsonrpc: Annotated[str, Field(description="JSON RPC Version", examples=["2.0"])]
params: Annotated[
-        Optional[Dict[str, Any]],  # noqa: UP006, UP007 - prevent pytest failing in back-compat
+ dict[str, Any] | None,
Field(description="Dictionary of parameters passed to the method."),
]
-class WorkerStateBody(BaseModel):
+class EdgeJobBase(BaseModel):
+ """Basic attributes of a job on the edge worker."""
+
+ dag_id: Annotated[
+        str, Field(title="Dag ID", description="Identifier of the DAG to which the task belongs.")
+ ]
+    task_id: Annotated[str, Field(title="Task ID", description="Task name in the DAG.")]
+    run_id: Annotated[str, Field(title="Run ID", description="Run ID of the DAG execution.")]
+ map_index: Annotated[
+ int,
+ Field(
+ title="Map Index",
+            description="For dynamically mapped tasks the mapping number, -1 if the task is not mapped.",
+ ),
+ ]
+ try_number: Annotated[
+        int, Field(title="Try Number", description="The number of the attempt to execute this task.")
+ ]
+
+ @property
+ def key(self) -> TaskInstanceKey:
+        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
+
+
+class EdgeJobFetched(EdgeJobBase):
+ """Job that is to be executed on the edge worker."""
+
+ command: Annotated[
+        list[str], Field(title="Command", description="Command line to use to execute the job.")
+ ]
+    concurrency_slots: Annotated[int, Field(description="Number of concurrency slots the job requires.")]
+
+
+class WorkerQueuesBase(BaseModel):
+ """Queues that a worker supports to run jobs on."""
+
+ queues: Annotated[
+ list[str] | None,
+ Field(
+ None,
+            description="List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues.",
+ ),
+ ]
+
+
+class WorkerQueuesBody(WorkerQueuesBase):
+ """Queues that a worker supports to run jobs on."""
+
+    free_concurrency: Annotated[int, Field(description="Number of free concurrency slots on the worker.")]
+
+
+class WorkerStateBody(WorkerQueuesBase):
"""Details of the worker state sent to the scheduler."""
    state: Annotated[EdgeWorkerState, Field(description="State of the worker from the view of the worker.")]
    jobs_active: Annotated[int, Field(description="Number of active jobs the worker is running.")] = 0
queues: Annotated[
-        Optional[List[str]],  # noqa: UP006, UP007 - prevent pytest failing in back-compat
+ list[str] | None,
Field(
description="List of queues the worker is pulling jobs from. If
not provided, worker pulls from all queues."
),
] = None
sysinfo: Annotated[
-        Dict[str, Union[str, int]],  # noqa: UP006, UP007 - prevent pytest failing in back-compat
+ dict[str, str | int],
Field(
description="System information of the worker.",
examples=[
{
"concurrency": 4,
+ "free_concurrency": 3,
"airflow_version": "2.0.0",
"edge_provider_version": "1.0.0",
}
@@ -94,11 +144,11 @@ class WorkerQueueUpdateBody(BaseModel):
"""Changed queues for the worker."""
new_queues: Annotated[
-        Optional[List[str]],  # noqa: UP006, UP007 - prevent pytest failing in back-compat
+ list[str] | None,
Field(description="Additional queues to be added to worker."),
]
remove_queues: Annotated[
-        Optional[List[str]],  # noqa: UP006, UP007 - prevent pytest failing in back-compat
+ list[str] | None,
Field(description="Queues to remove from worker."),
]
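
The new models are plain pydantic v2 models; a short sketch of how they behave (values illustrative; exclude_unset=True matches the client-side serialization above):

    from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched, WorkerQueuesBody

    body = WorkerQueuesBody(free_concurrency=2)  # queues defaults to None
    body.model_dump_json(exclude_unset=True)     # -> '{"free_concurrency":2}'

    job = EdgeJobFetched(
        dag_id="d", task_id="t", run_id="r", map_index=-1, try_number=1,
        command=["airflow", "tasks", "run"], concurrency_slots=1,
    )
    job.key  # TaskInstanceKey(dag_id='d', task_id='t', run_id='r', try_number=1, map_index=-1)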
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
index 767aef14b34..b00f5cf41b9 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
@@ -28,8 +28,15 @@ from flask import Response, request
from airflow.exceptions import AirflowException
from airflow.providers.edge.worker_api.auth import jwt_token_authorization, jwt_token_authorization_rpc
-from airflow.providers.edge.worker_api.datamodels import JsonRpcRequest, PushLogsBody, WorkerStateBody
+from airflow.providers.edge.worker_api.datamodels import (
+ EdgeJobFetched,
+ JsonRpcRequest,
+ PushLogsBody,
+ WorkerQueuesBody,
+ WorkerStateBody,
+)
from airflow.providers.edge.worker_api.routes._v2_compat import HTTPException, status
+from airflow.providers.edge.worker_api.routes.jobs import fetch, state as state_api
from airflow.providers.edge.worker_api.routes.logs import logfile_path, push_logs
from airflow.providers.edge.worker_api.routes.worker import register, set_state
from airflow.serialization.serialized_objects import BaseSerialization
@@ -37,6 +44,7 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session
if TYPE_CHECKING:
from airflow.api_connexion.types import APIResponse
+ from airflow.utils.state import TaskInstanceState
log = logging.getLogger(__name__)
@@ -267,6 +275,44 @@ def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) ->
return e.to_response() # type: ignore[attr-defined]
+@provide_session
+def job_fetch_v2(worker_name: str, body: dict[str, Any] | None = None, session=NEW_SESSION) -> Any:
+    """Handle Edge Worker API `/edge_worker/v1/jobs/fetch/{worker_name}` endpoint for Airflow 2.10."""
+ from flask import request
+
+ try:
+ auth = request.headers.get("Authorization", "")
+ jwt_token_authorization(request.path, auth)
+ queues = body["queues"] if body else None
+ free_concurrency = body["free_concurrency"] if body else 1
+        request_obj = WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency)
+ job: EdgeJobFetched | None = fetch(worker_name, request_obj, session)
+ return job.model_dump() if job is not None else None
+ except HTTPException as e:
+ return e.to_response() # type: ignore[attr-defined]
+
+
+@provide_session
+def job_state_v2(
+ dag_id: str,
+ task_id: str,
+ run_id: str,
+ try_number: int,
+    map_index: str,  # Note: Connexion does not support negative numbers in path parameters, therefore a string is used
+ state: TaskInstanceState,
+ session=NEW_SESSION,
+) -> Any:
+    """Handle Edge Worker API `/jobs/state/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}/{state}` endpoint for Airflow 2.10."""
+ from flask import request
+
+ try:
+ auth = request.headers.get("Authorization", "")
+ jwt_token_authorization(request.path, auth)
+        state_api(dag_id, task_id, run_id, try_number, int(map_index), state, session)
+ except HTTPException as e:
+ return e.to_response() # type: ignore[attr-defined]
+
+
def logfile_path_v2(
dag_id: str,
task_id: str,
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/jobs.py b/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
new file mode 100644
index 00000000000..289fc3eed99
--- /dev/null
+++ b/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from ast import literal_eval
+from typing import Annotated
+
+from sqlalchemy import select, update
+
+from airflow.providers.edge.models.edge_job import EdgeJobModel
+from airflow.providers.edge.worker_api.auth import jwt_token_authorization_rest
+from airflow.providers.edge.worker_api.datamodels import (
+ EdgeJobFetched,
+ WorkerApiDocs,
+ WorkerQueuesBody,
+)
+from airflow.providers.edge.worker_api.routes._v2_compat import (
+ AirflowRouter,
+ Body,
+ Depends,
+ SessionDep,
+ create_openapi_http_exception_doc,
+ status,
+)
+from airflow.utils import timezone
+from airflow.utils.sqlalchemy import with_row_locks
+from airflow.utils.state import TaskInstanceState
+
+jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
+
+
+@jobs_router.get(
+ "/fetch/{worker_name}",
+ dependencies=[Depends(jwt_token_authorization_rest)],
+ responses=create_openapi_http_exception_doc(
+ [
+ status.HTTP_400_BAD_REQUEST,
+ status.HTTP_403_FORBIDDEN,
+ ]
+ ),
+)
+def fetch(
+ worker_name: str,
+ body: Annotated[
+ WorkerQueuesBody,
+ Body(
+            title="Queues and capacity",
+            description="The queues and capacity from which the worker can fetch jobs.",
+ ),
+ ],
+ session: SessionDep,
+) -> EdgeJobFetched | None:
+ """Fetch a job to execute on the edge worker."""
+ query = (
+ select(EdgeJobModel)
+ .where(
+ EdgeJobModel.state == TaskInstanceState.QUEUED,
+ EdgeJobModel.concurrency_slots <= body.free_concurrency,
+ )
+ .order_by(EdgeJobModel.queued_dttm)
+ )
+ if body.queues:
+ query = query.where(EdgeJobModel.queue.in_(body.queues))
+ query = query.limit(1)
+    query = with_row_locks(query, of=EdgeJobModel, session=session, skip_locked=True)
+ job: EdgeJobModel = session.scalar(query)
+ if not job:
+ return None
+ job.state = TaskInstanceState.RUNNING
+ job.edge_worker = worker_name
+ job.last_update = timezone.utcnow()
+ session.commit()
+ return EdgeJobFetched(
+ dag_id=job.dag_id,
+ task_id=job.task_id,
+ run_id=job.run_id,
+ map_index=job.map_index,
+ try_number=job.try_number,
+ command=literal_eval(job.command),
+ concurrency_slots=job.concurrency_slots,
+ )
+
+
+@jobs_router.patch(
+ "/state/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}/{state}",
+ dependencies=[Depends(jwt_token_authorization_rest)],
+ responses=create_openapi_http_exception_doc(
+ [
+ status.HTTP_400_BAD_REQUEST,
+ status.HTTP_403_FORBIDDEN,
+ ]
+ ),
+)
+def state(
+ dag_id: Annotated[str, WorkerApiDocs.dag_id],
+ task_id: Annotated[str, WorkerApiDocs.task_id],
+ run_id: Annotated[str, WorkerApiDocs.run_id],
+ try_number: Annotated[int, WorkerApiDocs.try_number],
+ map_index: Annotated[int, WorkerApiDocs.map_index],
+ state: Annotated[TaskInstanceState, WorkerApiDocs.state],
+ session: SessionDep,
+) -> None:
+ """Update the state of a job running on the edge worker."""
+ query = (
+ update(EdgeJobModel)
+ .where(
+ EdgeJobModel.dag_id == dag_id,
+ EdgeJobModel.task_id == task_id,
+ EdgeJobModel.run_id == run_id,
+ EdgeJobModel.map_index == map_index,
+ EdgeJobModel.try_number == try_number,
+ )
+ .values(state=state, last_update=timezone.utcnow())
+ )
+ session.execute(query)
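
The reservation in fetch() relies on row locking so that two workers polling at the same moment cannot grab the same job; a sketch of the semantics, assuming an open SQLAlchemy session:

    from sqlalchemy import select

    from airflow.providers.edge.models.edge_job import EdgeJobModel
    from airflow.utils.sqlalchemy import with_row_locks
    from airflow.utils.state import TaskInstanceState

    query = select(EdgeJobModel).where(EdgeJobModel.state == TaskInstanceState.QUEUED).limit(1)
    # On databases that support it this appends "FOR UPDATE SKIP LOCKED": a row
    # locked by one worker's transaction is silently skipped by the next query.
    query = with_row_locks(query, of=EdgeJobModel, session=session, skip_locked=True)
    job = session.scalar(query)  # None when all matching rows are taken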
diff --git a/providers/tests/edge/cli/test_edge_command.py b/providers/tests/edge/cli/test_edge_command.py
index 4f2706ef531..123b06af3f9 100644
--- a/providers/tests/edge/cli/test_edge_command.py
+++ b/providers/tests/edge/cli/test_edge_command.py
@@ -28,8 +28,8 @@ import time_machine
from airflow.exceptions import AirflowException
from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _Job, _write_pid_to_pidfile
-from airflow.providers.edge.models.edge_job import EdgeJob
from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException
+from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched
from airflow.utils import timezone
from airflow.utils.state import TaskInstanceState
@@ -95,19 +95,14 @@ class TestEdgeWorkerCli:
return [
_Job(
- edge_job=EdgeJob(
+ edge_job=EdgeJobFetched(
dag_id="test",
task_id="test1",
run_id="test",
map_index=-1,
try_number=1,
- state=TaskInstanceState.RUNNING,
- queue="test",
concurrency_slots=1,
command=["test", "command"],
- queued_dttm=datetime.now(),
- edge_worker=None,
- last_update=None,
),
process=_MockPopen(),
logfile=logfile,
@@ -126,19 +121,14 @@ class TestEdgeWorkerCli:
[
pytest.param(None, False, (0, 0), id="no_job"),
pytest.param(
- EdgeJob(
+ EdgeJobFetched(
dag_id="test",
task_id="test",
run_id="test",
map_index=-1,
try_number=1,
- state=TaskInstanceState.QUEUED,
- queue="test",
concurrency_slots=1,
command=["test", "command"],
- queued_dttm=datetime.now(),
- edge_worker=None,
- last_update=None,
),
True,
(1, 1),
@@ -146,9 +136,9 @@ class TestEdgeWorkerCli:
),
],
)
- @patch("airflow.providers.edge.models.edge_job.EdgeJob.reserve_task")
+ @patch("airflow.providers.edge.cli.edge_command.jobs_fetch")
@patch("airflow.providers.edge.cli.edge_command.logs_logfile_path")
- @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+ @patch("airflow.providers.edge.cli.edge_command.jobs_set_state")
@patch("subprocess.Popen")
def test_fetch_job(
self,
@@ -181,7 +171,7 @@ class TestEdgeWorkerCli:
== worker_with_job.concurrency - worker_with_job.jobs[0].edge_job.concurrency_slots
)
- @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+ @patch("airflow.providers.edge.cli.edge_command.jobs_set_state")
def test_check_running_jobs_success(self, mock_set_state, worker_with_job: _EdgeWorkerCli):
job = worker_with_job.jobs[0]
job.process.generated_returncode = 0 # type: ignore[attr-defined]
@@ -191,7 +181,7 @@ class TestEdgeWorkerCli:
mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.SUCCESS)
assert worker_with_job.free_concurrency == worker_with_job.concurrency
- @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+ @patch("airflow.providers.edge.cli.edge_command.jobs_set_state")
def test_check_running_jobs_failed(self, mock_set_state, worker_with_job: _EdgeWorkerCli):
job = worker_with_job.jobs[0]
job.process.generated_returncode = 42 # type: ignore[attr-defined]
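
The updated patch targets follow the usual unittest.mock rule: patch the name where it is looked up (edge_command), not where it is defined (api_client). A sketch of the pattern:

    from unittest.mock import patch

    with patch("airflow.providers.edge.cli.edge_command.jobs_set_state") as mock_set_state:
        ...  # exercise the worker, e.g. _EdgeWorkerCli.check_running_jobs(), then assert:
        # mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.SUCCESS)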