This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6057a2e04e2 Migrate Edge calls for Worker to FastAPI part 1 - Worker
routes (#44311)
6057a2e04e2 is described below
commit 6057a2e04e2488681f0874d236f26385c084a7ac
Author: Jens Scheffler <[email protected]>
AuthorDate: Sat Nov 30 22:48:09 2024 +0100
Migrate Edge calls for Worker to FastAPI part 1 - Worker routes (#44311)
---
docs/spelling_wordlist.txt | 1 +
providers/src/airflow/providers/edge/CHANGELOG.rst | 9 +
providers/src/airflow/providers/edge/__init__.py | 2 +-
.../src/airflow/providers/edge/cli/api_client.py | 114 ++++++++++
.../src/airflow/providers/edge/cli/edge_command.py | 53 ++++-
.../providers/edge/executors/edge_executor.py | 4 +-
.../airflow/providers/edge/models/edge_worker.py | 200 ++++++-----------
.../providers/edge/openapi/edge_worker_api_v1.yaml | 236 +++++++++++++++++++-
providers/src/airflow/providers/edge/provider.yaml | 2 +-
.../src/airflow/providers/edge/worker_api/app.py | 4 +-
.../src/airflow/providers/edge/worker_api/auth.py | 110 ++++++++++
.../providers/edge/worker_api/datamodels.py | 74 ++++++-
.../providers/edge/worker_api/routes/_v2_compat.py | 25 ++-
.../routes/{rpc_api.py => _v2_routes.py} | 244 +++++++++------------
.../providers/edge/worker_api/routes/worker.py | 178 +++++++++++++++
providers/tests/edge/cli/test_edge_command.py | 93 ++++----
.../routes/test_worker.py} | 70 +++---
17 files changed, 1027 insertions(+), 392 deletions(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index fa8ffb4a2c3..debbd43c896 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1619,6 +1619,7 @@ symlinking
symlinks
sync'ed
sys
+sysinfo
syspath
Systemd
systemd
diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst
b/providers/src/airflow/providers/edge/CHANGELOG.rst
index 301ca1d4d88..8309f111f6a 100644
--- a/providers/src/airflow/providers/edge/CHANGELOG.rst
+++ b/providers/src/airflow/providers/edge/CHANGELOG.rst
@@ -26,6 +26,15 @@
Changelog
---------
+
+0.8.0pre0
+.........
+
+Misc
+~~~~
+
+* ``Migrate worker registration and heartbeat to FastAPI.``
+
0.7.1pre0
.........
diff --git a/providers/src/airflow/providers/edge/__init__.py
b/providers/src/airflow/providers/edge/__init__.py
index 9b22a264d44..fd23acee829 100644
--- a/providers/src/airflow/providers/edge/__init__.py
+++ b/providers/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
__all__ = ["__version__"]
-__version__ = "0.7.1pre0"
+__version__ = "0.8.0pre0"
if
packaging.version.parse(packaging.version.parse(airflow_version).base_version)
< packaging.version.parse(
"2.10.0"
diff --git a/providers/src/airflow/providers/edge/cli/api_client.py
b/providers/src/airflow/providers/edge/cli/api_client.py
new file mode 100644
index 00000000000..9174191fd8c
--- /dev/null
+++ b/providers/src/airflow/providers/edge/cli/api_client.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import logging
+from datetime import datetime
+from http import HTTPStatus
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+from urllib.parse import quote, urljoin, urlparse
+
+import requests
+import tenacity
+from requests.exceptions import ConnectionError
+from urllib3.exceptions import NewConnectionError
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.edge.worker_api.auth import jwt_signer
+from airflow.providers.edge.worker_api.datamodels import WorkerStateBody
+
+if TYPE_CHECKING:
+ from airflow.providers.edge.models.edge_worker import EdgeWorkerState
+
+logger = logging.getLogger(__name__)
+
+
+def _is_retryable_exception(exception: BaseException) -> bool:
+ """
+ Evaluate which exception types to retry.
+
+ This is especially needed for cases where an application gateway or
Kubernetes ingress can
+ not find a running instance of a webserver hosting the API (HTTP 502+504)
or when the
+ HTTP request fails in general on network level.
+
+ Note that we want to fail on other general errors on the webserver not to
send bad requests in an endless loop.
+ """
+ retryable_status_codes = (HTTPStatus.BAD_GATEWAY,
HTTPStatus.GATEWAY_TIMEOUT)
+ return (
+ isinstance(exception, AirflowException)
+ and exception.status_code in retryable_status_codes
+ or isinstance(exception, (ConnectionError, NewConnectionError))
+ )
+
+
[email protected](
+ stop=tenacity.stop_after_attempt(10), # TODO: Make this configurable
+ wait=tenacity.wait_exponential(min=1), # TODO: Make this configurable
+ retry=tenacity.retry_if_exception(_is_retryable_exception),
+ before_sleep=tenacity.before_log(logger, logging.WARNING),
+)
+def _make_generic_request(method: str, rest_path: str, data: str) -> Any:
+ signer = jwt_signer()
+ api_url = conf.get("edge", "api_url")
+ path = urlparse(api_url).path.replace("/rpcapi", "")
+ headers = {
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ "Authorization": signer.generate_signed_token({"method":
str(Path(path, rest_path))}),
+ }
+ api_endpoint = urljoin(api_url, rest_path)
+ response = requests.request(method, url=api_endpoint, data=data,
headers=headers)
+ if response.status_code == HTTPStatus.NO_CONTENT:
+ return None
+ if response.status_code != HTTPStatus.OK:
+ raise AirflowException(
+ f"Got {response.status_code}:{response.reason} when sending "
+ f"the internal api request: {response.text}",
+ HTTPStatus(response.status_code),
+ )
+ return json.loads(response.content)
+
+
+def worker_register(
+ hostname: str, state: EdgeWorkerState, queues: list[str] | None, sysinfo:
dict
+) -> datetime:
+ """Register worker with the Edge API."""
+ result = _make_generic_request(
+ "POST",
+ f"worker/{quote(hostname)}",
+ WorkerStateBody(state=state, jobs_active=0, queues=queues,
sysinfo=sysinfo).model_dump_json(
+ exclude_unset=True
+ ),
+ )
+ return datetime.fromisoformat(result)
+
+
+def worker_set_state(
+ hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str]
| None, sysinfo: dict
+) -> list[str] | None:
+ """Set state of worker and receive the currently assigned queues."""
+ result = _make_generic_request(
+ "PATCH",
+ f"worker/{quote(hostname)}",
+ WorkerStateBody(state=state, jobs_active=jobs_active, queues=queues,
sysinfo=sysinfo).model_dump_json(
+ exclude_unset=True
+ ),
+ )
+ return result
diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py
b/providers/src/airflow/providers/edge/cli/edge_command.py
index 3712049b207..9d172bffdd5 100644
--- a/providers/src/airflow/providers/edge/cli/edge_command.py
+++ b/providers/src/airflow/providers/edge/cli/edge_command.py
@@ -14,13 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
from __future__ import annotations
import logging
import os
import platform
import signal
+import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@@ -29,15 +29,17 @@ from time import sleep
import psutil
from lockfile.pidlockfile import read_pid_from_pidfile,
remove_existing_pidfile, write_pid_to_pidfile
+from packaging.version import Version
from airflow import __version__ as airflow_version, settings
from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.edge import __version__ as edge_provider_version
+from airflow.providers.edge.cli.api_client import worker_register,
worker_set_state
from airflow.providers.edge.models.edge_job import EdgeJob
from airflow.providers.edge.models.edge_logs import EdgeLogs
-from airflow.providers.edge.models.edge_worker import EdgeWorker,
EdgeWorkerState, EdgeWorkerVersionException
+from airflow.providers.edge.models.edge_worker import EdgeWorkerState,
EdgeWorkerVersionException
from airflow.utils import cli as cli_utils
from airflow.utils.platform import IS_WINDOWS
from airflow.utils.providers_configuration_loader import
providers_configuration_loaded
@@ -57,6 +59,45 @@ EDGE_WORKER_HEADER = "\n".join(
)
+@providers_configuration_loaded
+def force_use_internal_api_on_edge_worker():
+ """
+ Ensure that the environment is configured for the internal API without
needing to declare it outside.
+
+ This is only required for an Edge worker and must be done before the
Click CLI wrapper is initiated.
+ That is because the CLI wrapper will attempt to establish a DB connection,
which will fail before the
+ function call can take effect. In an Edge worker, we need to "patch" the
environment before starting.
+ """
+ # export Edge API to be used for internal API
+ os.environ["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
+ os.environ["AIRFLOW_ENABLE_AIP_44"] = "True"
+ if "airflow" in sys.argv[0] and sys.argv[1:3] == ["edge", "worker"]:
+ AIRFLOW_VERSION = Version(airflow_version)
+ AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >=
Version("3.0.0")
+ if AIRFLOW_V_3_0_PLUS:
+ # Obvious TODO Make EdgeWorker compatible with Airflow 3 (again)
+ raise SystemExit(
"Error: EdgeWorker is currently broken on Airflow 3/main due
to removal of AIP-44, rework for AIP-72."
+ )
+
+ api_url = conf.get("edge", "api_url")
+ if not api_url:
+ raise SystemExit("Error: API URL is not configured, please correct
configuration.")
+ logger.info("Starting worker with API endpoint %s", api_url)
+ os.environ["AIRFLOW__CORE__INTERNAL_API_URL"] = api_url
+
+ from airflow.api_internal import internal_api_call
+ from airflow.serialization import serialized_objects
+
+ # Note: Need to patch internal settings as statically initialized
before we get here
+ serialized_objects._ENABLE_AIP_44 = True
+ internal_api_call._ENABLE_AIP_44 = True
+ internal_api_call.InternalApiConfig.set_use_internal_api("edge-worker")
+
+
+force_use_internal_api_on_edge_worker()
+
+
def _hostname() -> str:
if IS_WINDOWS:
return platform.uname().node
@@ -153,9 +194,9 @@ class _EdgeWorkerCli:
def start(self):
"""Start the execution in a loop until terminated."""
try:
- self.last_hb = EdgeWorker.register_worker(
+ self.last_hb = worker_register(
self.hostname, EdgeWorkerState.STARTING, self.queues,
self._get_sysinfo()
- ).last_update
+ )
except EdgeWorkerVersionException as e:
logger.info("Version mismatch of Edge worker and Core. Shutting
down worker.")
raise SystemExit(str(e))
@@ -172,7 +213,7 @@ class _EdgeWorkerCli:
logger.info("Quitting worker, signal being offline.")
try:
- EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE,
0, self._get_sysinfo())
+ worker_set_state(self.hostname, EdgeWorkerState.OFFLINE, 0,
self.queues, self._get_sysinfo())
except EdgeWorkerVersionException:
logger.info("Version mismatch of Edge worker and Core.
Quitting worker anyway.")
finally:
@@ -261,7 +302,7 @@ class _EdgeWorkerCli:
)
sysinfo = self._get_sysinfo()
try:
- self.queues = EdgeWorker.set_state(self.hostname, state,
len(self.jobs), sysinfo)
+ self.queues = worker_set_state(self.hostname, state,
len(self.jobs), self.queues, sysinfo)
except EdgeWorkerVersionException:
logger.info("Version mismatch of Edge worker and Core. Shutting
down worker.")
_EdgeWorkerCli.drain = True
diff --git a/providers/src/airflow/providers/edge/executors/edge_executor.py
b/providers/src/airflow/providers/edge/executors/edge_executor.py
index 48ae5e872e0..4184a8ffe5b 100644
--- a/providers/src/airflow/providers/edge/executors/edge_executor.py
+++ b/providers/src/airflow/providers/edge/executors/edge_executor.py
@@ -33,7 +33,7 @@ from airflow.models.taskinstance import TaskInstance,
TaskInstanceState
from airflow.providers.edge.cli.edge_command import EDGE_COMMANDS
from airflow.providers.edge.models.edge_job import EdgeJobModel
from airflow.providers.edge.models.edge_logs import EdgeLogsModel
-from airflow.providers.edge.models.edge_worker import EdgeWorker,
EdgeWorkerModel, EdgeWorkerState
+from airflow.providers.edge.models.edge_worker import EdgeWorkerModel,
EdgeWorkerState, reset_metrics
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.db import DBLocks, create_global_lock
@@ -145,7 +145,7 @@ class EdgeExecutor(BaseExecutor):
for worker in lifeless_workers:
changed = True
worker.state = EdgeWorkerState.UNKNOWN
- EdgeWorker.reset_metrics(worker.worker_name)
+ reset_metrics(worker.worker_name)
return changed
diff --git a/providers/src/airflow/providers/edge/models/edge_worker.py
b/providers/src/airflow/providers/edge/models/edge_worker.py
index a1287fdb96c..656d7539d07 100644
--- a/providers/src/airflow/providers/edge/models/edge_worker.py
+++ b/providers/src/airflow/providers/edge/models/edge_worker.py
@@ -23,12 +23,7 @@ from enum import Enum
from typing import TYPE_CHECKING, Optional
from pydantic import BaseModel, ConfigDict
-from sqlalchemy import (
- Column,
- Integer,
- String,
- select,
-)
+from sqlalchemy import Column, Integer, String, select
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException
@@ -129,8 +124,56 @@ class EdgeWorkerModel(Base, LoggingMixin):
self.queues = queues
+def set_metrics(
+ worker_name: str,
+ state: EdgeWorkerState,
+ jobs_active: int,
+ concurrency: int,
+ free_concurrency: int,
+ queues: list[str] | None,
+) -> None:
+ """Set metric of edge worker."""
+ queues = queues if queues else []
+ connected = state not in (EdgeWorkerState.UNKNOWN, EdgeWorkerState.OFFLINE)
+
+ Stats.gauge(f"edge_worker.state.{worker_name}", int(connected))
+ Stats.gauge(
+ "edge_worker.state",
+ int(connected),
+ tags={"name": worker_name, "state": state},
+ )
+
+ Stats.gauge(f"edge_worker.jobs_active.{worker_name}", jobs_active)
+ Stats.gauge("edge_worker.jobs_active", jobs_active, tags={"worker_name":
worker_name})
+
+ Stats.gauge(f"edge_worker.concurrency.{worker_name}", concurrency)
+ Stats.gauge("edge_worker.concurrency", concurrency, tags={"worker_name":
worker_name})
+
+ Stats.gauge(f"edge_worker.free_concurrency.{worker_name}",
free_concurrency)
+ Stats.gauge("edge_worker.free_concurrency", free_concurrency,
tags={"worker_name": worker_name})
+
+ Stats.gauge(f"edge_worker.num_queues.{worker_name}", len(queues))
+ Stats.gauge(
+ "edge_worker.num_queues",
+ len(queues),
+ tags={"worker_name": worker_name, "queues": ",".join(queues)},
+ )
+
+
+def reset_metrics(worker_name: str) -> None:
+ """Reset metrics of worker."""
+ set_metrics(
+ worker_name=worker_name,
+ state=EdgeWorkerState.UNKNOWN,
+ jobs_active=0,
+ concurrency=0,
+ free_concurrency=-1,
+ queues=None,
+ )
+
+
class EdgeWorker(BaseModel, LoggingMixin):
- """Accessor for Edge Worker instances as logical model."""
+ """Deprecated Edge Worker internal API, keeping for one minor for graceful
migration."""
worker_name: str
state: EdgeWorkerState
@@ -144,119 +187,6 @@ class EdgeWorker(BaseModel, LoggingMixin):
sysinfo: str
model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
- @staticmethod
- def set_metrics(
- worker_name: str,
- state: EdgeWorkerState,
- jobs_active: int,
- concurrency: int,
- free_concurrency: int,
- queues: list[str] | None,
- ) -> None:
- """Set metric of edge worker."""
- queues = queues if queues else []
- connected = state not in (EdgeWorkerState.UNKNOWN,
EdgeWorkerState.OFFLINE)
- Stats.gauge(f"edge_worker.state.{worker_name}", int(connected))
- Stats.gauge(
- "edge_worker.state",
- int(connected),
- tags={"name": worker_name, "state": state},
- )
-
- Stats.gauge(f"edge_worker.jobs_active.{worker_name}", jobs_active)
- Stats.gauge("edge_worker.jobs_active", jobs_active,
tags={"worker_name": worker_name})
-
- Stats.gauge(f"edge_worker.concurrency.{worker_name}", concurrency)
- Stats.gauge("edge_worker.concurrency", concurrency,
tags={"worker_name": worker_name})
-
- Stats.gauge(f"edge_worker.free_concurrency.{worker_name}",
free_concurrency)
- Stats.gauge("edge_worker.free_concurrency", free_concurrency,
tags={"worker_name": worker_name})
-
- Stats.gauge(
- f"edge_worker.num_queues.{worker_name}",
- len(queues),
- )
- Stats.gauge(
- "edge_worker.num_queues",
- len(queues),
- tags={"worker_name": worker_name, "queues": ",".join(queues)},
- )
-
- @staticmethod
- def reset_metrics(worker_name: str) -> None:
- """Reset metrics of worker."""
- EdgeWorker.set_metrics(
- worker_name=worker_name,
- state=EdgeWorkerState.UNKNOWN,
- jobs_active=0,
- concurrency=0,
- free_concurrency=-1,
- queues=None,
- )
-
- @staticmethod
- def assert_version(sysinfo: dict[str, str]) -> None:
- """Check if the Edge Worker version matches the central API site."""
- from airflow import __version__ as airflow_version
- from airflow.providers.edge import __version__ as edge_provider_version
-
- # Note: In future, more stable versions we might be more liberate, for
the
- # moment we require exact version match for Edge Worker and core
version
- if "airflow_version" in sysinfo:
- airflow_on_worker = sysinfo["airflow_version"]
- if airflow_on_worker != airflow_version:
- raise EdgeWorkerVersionException(
- f"Edge Worker runs on Airflow {airflow_on_worker} "
- f"and the core runs on {airflow_version}. Rejecting access
due to difference."
- )
- else:
- raise EdgeWorkerVersionException("Edge Worker does not specify the
version it is running on.")
-
- if "edge_provider_version" in sysinfo:
- provider_on_worker = sysinfo["edge_provider_version"]
- if provider_on_worker != edge_provider_version:
- raise EdgeWorkerVersionException(
- f"Edge Worker runs on Edge Provider {provider_on_worker} "
- f"and the core runs on {edge_provider_version}. Rejecting
access due to difference."
- )
- else:
- raise EdgeWorkerVersionException(
- "Edge Worker does not specify the provider version it is
running on."
- )
-
- @staticmethod
- @internal_api_call
- @provide_session
- def register_worker(
- worker_name: str,
- state: EdgeWorkerState,
- queues: list[str] | None,
- sysinfo: dict[str, str],
- session: Session = NEW_SESSION,
- ) -> EdgeWorker:
- EdgeWorker.assert_version(sysinfo)
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
- if not worker:
- worker = EdgeWorkerModel(worker_name=worker_name, state=state,
queues=queues)
- worker.state = state
- worker.queues = queues
- worker.sysinfo = json.dumps(sysinfo)
- worker.last_update = timezone.utcnow()
- session.add(worker)
- return EdgeWorker(
- worker_name=worker_name,
- state=state,
- queues=queues,
- first_online=worker.first_online,
- last_update=worker.last_update,
- jobs_active=worker.jobs_active or 0,
- jobs_taken=worker.jobs_taken or 0,
- jobs_success=worker.jobs_success or 0,
- jobs_failed=worker.jobs_failed or 0,
- sysinfo=worker.sysinfo or "{}",
- )
-
@staticmethod
@internal_api_call
@provide_session
@@ -277,7 +207,7 @@ class EdgeWorker(BaseModel, LoggingMixin):
session.commit()
Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1)
Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name":
worker_name})
- EdgeWorker.set_metrics(
+ set_metrics(
worker_name=worker_name,
state=state,
jobs_active=jobs_active,
@@ -285,25 +215,21 @@ class EdgeWorker(BaseModel, LoggingMixin):
free_concurrency=int(sysinfo["free_concurrency"]),
queues=worker.queues,
)
- EdgeWorker.assert_version(sysinfo) # Exception only after worker
state is in the DB
- return worker.queues
+ raise EdgeWorkerVersionException(
+ "Edge Worker runs on an old version. Rejecting access due to
difference."
+ )
@staticmethod
- @provide_session
- def add_and_remove_queues(
+ @internal_api_call
+ def register_worker(
worker_name: str,
- new_queues: list[str] | None = None,
- remove_queues: list[str] | None = None,
- session: Session = NEW_SESSION,
- ) -> None:
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
- if new_queues:
- worker.add_queues(new_queues)
- if remove_queues:
- worker.remove_queues(remove_queues)
- session.add(worker)
- session.commit()
+ state: EdgeWorkerState,
+ queues: list[str] | None,
+ sysinfo: dict[str, str],
+ ) -> EdgeWorker:
+ raise EdgeWorkerVersionException(
+ "Edge Worker runs on an old version. Rejecting access due to
difference."
+ )
EdgeWorker.model_rebuild()
diff --git
a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
index f1ab5f4c05e..8be23c0d07c 100644
--- a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
+++ b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
@@ -35,11 +35,154 @@ servers:
- url: /edge_worker/v1
description: Airflow Edge Worker API
paths:
- "/rpcapi":
+ /worker/{worker_name}:
+ patch:
+ description: Set state of worker and returns the current assigned queues.
+ x-openapi-router-controller:
airflow.providers.edge.worker_api.routes._v2_routes
+ operationId: set_state_v2
+ parameters:
+ - description: Hostname or instance name of the worker
+ in: path
+ name: worker_name
+ required: true
+ schema:
+ description: Hostname or instance name of the worker
+ title: Worker Name
+ type: string
+ - description: JWT Authorization Token
+ in: header
+ name: authorization
+ required: true
+ schema:
+ description: JWT Authorization Token
+ title: Authorization
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/WorkerStateBody'
+ description: State of the worker with details
+ examples:
+ - jobs_active: 3
+ queues:
+ - large_node
+ - wisconsin_site
+ state: running
+ sysinfo:
+ airflow_version: 2.10.0
+ concurrency: 4
+ edge_provider_version: 1.0.0
+ title: Worker State
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ anyOf:
+ - items:
+ type: string
+ type: array
+ - type: object
+ nullable: true
+ title: Response Set State
+ description: Successful Response
+ '400':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Bad Request
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '422':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
+ description: Validation Error
+ summary: Set State
+ tags:
+ - Worker
+ post:
+ description: Register a new worker to the backend.
+ x-openapi-router-controller:
airflow.providers.edge.worker_api.routes._v2_routes
+ operationId: register_v2
+ parameters:
+ - description: Hostname or instance name of the worker
+ in: path
+ name: worker_name
+ required: true
+ schema:
+ description: Hostname or instance name of the worker
+ title: Worker Name
+ type: string
+ - description: JWT Authorization Token
+ in: header
+ name: authorization
+ required: true
+ schema:
+ description: JWT Authorization Token
+ title: Authorization
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/WorkerStateBody'
+ description: State of the worker with details
+ examples:
+ - jobs_active: 3
+ queues:
+ - large_node
+ - wisconsin_site
+ state: running
+ sysinfo:
+ airflow_version: 2.10.0
+ concurrency: 4
+ edge_provider_version: 1.0.0
+ title: Worker State
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ format: date-time
+ title: Response Register
+ type: string
+ description: Successful Response
+ '400':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Bad Request
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '422':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
+ description: Validation Error
+ summary: Register
+ tags:
+ - Worker
+ /rpcapi:
post:
deprecated: false
- x-openapi-router-controller:
airflow.providers.edge.worker_api.routes.rpc_api
- operationId: edge_worker_api_v2
+ x-openapi-router-controller:
airflow.providers.edge.worker_api.routes._v2_routes
+ operationId: rpcapi_v2
tags:
- JSONRPC
parameters: []
@@ -68,7 +211,7 @@ paths:
params:
title: Parameters
type: object
- "/health":
+ /health:
get:
operationId: health
deprecated: false
@@ -99,4 +242,89 @@ components:
description: JSON-RPC Version (2.0)
discriminator:
propertyName: method_name
+ EdgeWorkerState:
+ description: Status of an Edge Worker instance.
+ enum:
+ - starting
+ - running
+ - idle
+ - terminating
+ - offline
+ - unknown
+ title: EdgeWorkerState
+ type: string
+ WorkerStateBody:
+ description: Details of the worker state sent to the scheduler.
+ type: object
+ required:
+ - state
+ - queues
+ - sysinfo
+ properties:
+ jobs_active:
+ default: 0
+ description: Number of active jobs the worker is running.
+ title: Jobs Active
+ type: integer
+ queues:
+ anyOf:
+ - items:
+ type: string
+ type: array
+ - type: object
+ nullable: true
+ description: List of queues the worker is pulling jobs from. If not
provided,
+ worker pulls from all queues.
+ title: Queues
+ state:
+ $ref: '#/components/schemas/EdgeWorkerState'
+ description: State of the worker from the view of the worker.
+ sysinfo:
+ description: System information of the worker.
+ title: Sysinfo
+ type: object
+ title: WorkerStateBody
+ HTTPExceptionResponse:
+ description: HTTPException Model used for error response.
+ properties:
+ detail:
+ anyOf:
+ - type: string
+ - type: object
+ title: Detail
+ required:
+ - detail
+ title: HTTPExceptionResponse
+ type: object
+ HTTPValidationError:
+ properties:
+ detail:
+ items:
+ $ref: '#/components/schemas/ValidationError'
+ title: Detail
+ type: array
+ title: HTTPValidationError
+ type: object
+ ValidationError:
+ properties:
+ loc:
+ items:
+ anyOf:
+ - type: string
+ - type: integer
+ title: Location
+ type: array
+ msg:
+ title: Message
+ type: string
+ type:
+ title: Error Type
+ type: string
+ required:
+ - loc
+ - msg
+ - type
+ title: ValidationError
+ type: object
+
tags: []
diff --git a/providers/src/airflow/providers/edge/provider.yaml
b/providers/src/airflow/providers/edge/provider.yaml
index 95827a44e9b..25dd75a2624 100644
--- a/providers/src/airflow/providers/edge/provider.yaml
+++ b/providers/src/airflow/providers/edge/provider.yaml
@@ -27,7 +27,7 @@ source-date-epoch: 1729683247
# note that those versions are maintained by release manager - do not update
them manually
versions:
- - 0.7.1pre0
+ - 0.8.0pre0
dependencies:
- apache-airflow>=2.10.0
diff --git a/providers/src/airflow/providers/edge/worker_api/app.py
b/providers/src/airflow/providers/edge/worker_api/app.py
index bfe9ef4c5bc..69a43edb116 100644
--- a/providers/src/airflow/providers/edge/worker_api/app.py
+++ b/providers/src/airflow/providers/edge/worker_api/app.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from fastapi import FastAPI
from airflow.providers.edge.worker_api.routes.health import health_router
-from airflow.providers.edge.worker_api.routes.rpc_api import rpc_api_router
+from airflow.providers.edge.worker_api.routes.worker import worker_router
def create_edge_worker_api_app() -> FastAPI:
@@ -35,5 +35,5 @@ def create_edge_worker_api_app() -> FastAPI:
)
edge_worker_api_app.include_router(health_router)
- edge_worker_api_app.include_router(rpc_api_router)
+ edge_worker_api_app.include_router(worker_router)
return edge_worker_api_app
diff --git a/providers/src/airflow/providers/edge/worker_api/auth.py
b/providers/src/airflow/providers/edge/worker_api/auth.py
new file mode 100644
index 00000000000..5829e94732b
--- /dev/null
+++ b/providers/src/airflow/providers/edge/worker_api/auth.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from functools import cache
+from uuid import uuid4
+
+from itsdangerous import BadSignature
+from jwt import (
+ ExpiredSignatureError,
+ ImmatureSignatureError,
+ InvalidAudienceError,
+ InvalidIssuedAtError,
+ InvalidSignatureError,
+)
+
+from airflow.configuration import conf
+from airflow.providers.edge.worker_api.datamodels import JsonRpcRequestBase #
noqa: TCH001
+from airflow.providers.edge.worker_api.routes._v2_compat import (
+ Header,
+ HTTPException,
+ Request,
+ status,
+)
+from airflow.utils.jwt_signer import JWTSigner
+
+log = logging.getLogger(__name__)
+
+
+@cache
+def jwt_signer() -> JWTSigner:
+ clock_grace = conf.getint("core", "internal_api_clock_grace", fallback=30)
+ return JWTSigner(
+ secret_key=conf.get("core", "internal_api_secret_key"),
+ expiration_time_in_seconds=clock_grace,
+ leeway_in_seconds=clock_grace,
+ audience="api",
+ )
+
+
+def _forbidden_response(message: str):
+ """Log the error and return the response anonymized."""
+ error_id = uuid4()
+ log.exception("%s error_id=%s", message, error_id)
+ raise HTTPException(
+ status.HTTP_403_FORBIDDEN,
+ f"Forbidden. The server side traceback may be identified with
error_id={error_id}",
+ )
+
+
+def jwt_token_authorization(method: str, authorization: str):
+ """Check if the JWT token is correct."""
+ try:
+ payload = jwt_signer().verify_token(authorization)
+ signed_method = payload.get("method")
+ if not signed_method or signed_method != method:
+ _forbidden_response(
+ "Invalid method in token authorization. "
+ f"signed method='{signed_method}' "
+ f"called method='{method}'",
+ )
+ except BadSignature:
+ _forbidden_response("Bad Signature. Please use only the tokens
provided by the API.")
+ except InvalidAudienceError:
+ _forbidden_response("Invalid audience for the request")
+ except InvalidSignatureError:
+ _forbidden_response("The signature of the request was wrong")
+ except ImmatureSignatureError:
+ _forbidden_response("The signature of the request was sent from the
future")
+ except ExpiredSignatureError:
+ _forbidden_response(
+ "The signature of the request has expired. Make sure that all
components "
+ "in your system have synchronized clocks.",
+ )
+ except InvalidIssuedAtError:
+ _forbidden_response(
+ "The request was issued in the future. Make sure that all
components "
+ "in your system have synchronized clocks.",
+ )
+ except Exception:
+ _forbidden_response("Unable to authenticate API via token.")
+
+
+def jwt_token_authorization_rpc(
+ body: JsonRpcRequestBase, authorization: str = Header(description="JWT
Authorization Token")
+):
+    """Check if the JWT token is correct for JSON RPC requests."""
+ jwt_token_authorization(body.method, authorization)
+
+
+def jwt_token_authorization_rest(
+ request: Request, authorization: str = Header(description="JWT
Authorization Token")
+):
+ """Check if the JWT token is correct for REST API requests."""
+ jwt_token_authorization(request.url.path, authorization)
diff --git a/providers/src/airflow/providers/edge/worker_api/datamodels.py
b/providers/src/airflow/providers/edge/worker_api/datamodels.py
index 9ce181bc726..170d8c449ff 100644
--- a/providers/src/airflow/providers/edge/worker_api/datamodels.py
+++ b/providers/src/airflow/providers/edge/worker_api/datamodels.py
@@ -16,17 +16,73 @@
# under the License.
from __future__ import annotations
-from typing import Any, Optional
+from typing import ( # noqa: UP035 - prevent pytest failing in back-compat
+ Annotated,
+ Any,
+ Dict,
+ List,
+ Optional,
+ Union,
+)
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from airflow.providers.edge.models.edge_worker import EdgeWorkerState # noqa:
TCH001
-class JsonRpcRequest(BaseModel):
+
+class JsonRpcRequestBase(BaseModel):
+ """Base JSON RPC request model to define just the method."""
+
+ method: Annotated[
+ str,
+ Field(description="Fully qualified python module method name that is
called via JSON RPC."),
+ ]
+
+
+class JsonRpcRequest(JsonRpcRequestBase):
"""JSON RPC request model."""
- method: str
- """Fully qualified python module method name that is called via JSON
RPC."""
- jsonrpc: str
- """JSON RPC version."""
- params: Optional[dict[str, Any]] = None # noqa: UP007 - prevent pytest
failing in back-compat
- """Parameters passed to the method."""
+ jsonrpc: Annotated[str, Field(description="JSON RPC Version",
examples=["2.0"])]
+ params: Annotated[
+ Optional[Dict[str, Any]], # noqa: UP006, UP007 - prevent pytest
failing in back-compat
+ Field(description="Dictionary of parameters passed to the method."),
+ ]
+
+
+class WorkerStateBody(BaseModel):
+ """Details of the worker state sent to the scheduler."""
+
+ state: Annotated[EdgeWorkerState, Field(description="State of the worker
from the view of the worker.")]
+ jobs_active: Annotated[int, Field(description="Number of active jobs the
worker is running.")] = 0
+ queues: Annotated[
+ Optional[List[str]], # noqa: UP006, UP007 - prevent pytest failing in
back-compat
+ Field(
+ description="List of queues the worker is pulling jobs from. If
not provided, worker pulls from all queues."
+ ),
+ ] = None
+ sysinfo: Annotated[
+ Dict[str, Union[str, int]], # noqa: UP006, UP007 - prevent pytest
failing in back-compat
+ Field(
+ description="System information of the worker.",
+ examples=[
+ {
+ "concurrency": 4,
+ "airflow_version": "2.0.0",
+ "edge_provider_version": "1.0.0",
+ }
+ ],
+ ),
+ ]
+
+
+class WorkerQueueUpdateBody(BaseModel):
+ """Changed queues for the worker."""
+
+ new_queues: Annotated[
+ Optional[List[str]], # noqa: UP006, UP007 - prevent pytest failing in
back-compat
+ Field(description="Additional queues to be added to worker."),
+ ]
+ remove_queues: Annotated[
+ Optional[List[str]], # noqa: UP006, UP007 - prevent pytest failing in
back-compat
+ Field(description="Queues to remove from worker."),
+ ]
diff --git
a/providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py
b/providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py
index 9774b3e6696..553456d4108 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py
@@ -27,8 +27,9 @@ AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >=
Version("3.0.0")
if AIRFLOW_V_3_0_PLUS:
# Just re-import the types from FastAPI and Airflow Core
- from fastapi import Depends, Header, HTTPException, status
+ from fastapi import Body, Depends, Header, HTTPException, Path, Request,
status
+ from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
else:
@@ -37,17 +38,33 @@ else:
from connexion import ProblemException
+ class Body: # type: ignore[no-redef]
+ def __init__(self, *_, **__):
+ pass
+
class Depends: # type: ignore[no-redef]
def __init__(self, *_, **__):
pass
class Header: # type: ignore[no-redef]
+ def __init__(self, *_, **__):
+ pass
+
+ class Path: # type: ignore[no-redef]
+ def __init__(self, *_, **__):
+ pass
+
+ class Request: # type: ignore[no-redef]
+ pass
+
+ class SessionDep: # type: ignore[no-redef]
pass
def create_openapi_http_exception_doc(responses_status_code: list[int]) ->
dict:
return {}
class status: # type: ignore[no-redef]
+ HTTP_204_NO_CONTENT = 204
HTTP_400_BAD_REQUEST = 400
HTTP_403_FORBIDDEN = 403
HTTP_500_INTERNAL_SERVER_ERROR = 500
@@ -100,3 +117,9 @@ else:
return func
return decorator
+
+ def patch(self, *_, **__):
+ def decorator(func: Callable) -> Callable:
+ return func
+
+ return decorator
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
similarity index 52%
rename from providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
rename to providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
index aa5b30f5ab7..6f2e81caa00 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Compatibility layer for Connexion API to Airflow v2.10 API routes."""
from __future__ import annotations
@@ -23,45 +24,39 @@ from functools import cache
from typing import TYPE_CHECKING, Any, Callable
from uuid import uuid4
-from itsdangerous import BadSignature
-from jwt import (
- ExpiredSignatureError,
- ImmatureSignatureError,
- InvalidAudienceError,
- InvalidIssuedAtError,
- InvalidSignatureError,
-)
-
-from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.providers.edge.worker_api.datamodels import JsonRpcRequest
-from airflow.providers.edge.worker_api.routes._v2_compat import (
- AirflowRouter,
- Depends,
- Header,
- HTTPException,
- create_openapi_http_exception_doc,
- status,
-)
+from airflow.providers.edge.worker_api.auth import jwt_token_authorization,
jwt_token_authorization_rpc
+from airflow.providers.edge.worker_api.datamodels import JsonRpcRequest,
WorkerStateBody
+from airflow.providers.edge.worker_api.routes._v2_compat import HTTPException,
status
+from airflow.providers.edge.worker_api.routes.worker import register, set_state
from airflow.serialization.serialized_objects import BaseSerialization
-from airflow.utils.jwt_signer import JWTSigner
-from airflow.utils.session import create_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
if TYPE_CHECKING:
from airflow.api_connexion.types import APIResponse
+
log = logging.getLogger(__name__)
-rpc_api_router = AirflowRouter(tags=["JSONRPC"])
@cache
def _initialize_method_map() -> dict[str, Callable]:
+ # Note: This is a copy of the (removed) AIP-44 implementation from
+ # airflow/api_internal/endpoints/rpc_api_endpoint.py
+ # for compatibility with Airflow 2.10-line.
+    #       Methods may no longer exist on the main branch for
Airflow 3.
+ from airflow.api.common.trigger_dag import trigger_dag
from airflow.cli.commands.task_command import _get_ti_db_access
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.dag_processing.processor import DagFileProcessor
+
+ # Airflow 2.10 compatibility
+ from airflow.datasets import expand_alias_to_datasets # type:
ignore[attr-defined]
+ from airflow.datasets.manager import DatasetManager # type:
ignore[attr-defined]
from airflow.jobs.job import Job, most_recent_job
from airflow.models import Trigger, Variable, XCom
from airflow.models.dag import DAG, DagModel
+ from airflow.models.dagcode import DagCode
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarning
from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -76,14 +71,13 @@ def _initialize_method_map() -> dict[str, Callable]:
_handle_reschedule,
_record_task_map_for_downstreams,
_update_rtif,
- _update_ti_heartbeat,
_xcom_pull,
)
from airflow.models.xcom_arg import _get_task_map_length
from airflow.providers.edge.models.edge_job import EdgeJob
from airflow.providers.edge.models.edge_logs import EdgeLogs
from airflow.providers.edge.models.edge_worker import EdgeWorker
- from airflow.sdk.definitions.asset import expand_alias_to_assets
+ from airflow.secrets.metastore import MetastoreBackend
from airflow.sensors.base import _orig_start_date
from airflow.utils.cli_action_loggers import _default_action_log_internal
from airflow.utils.log.file_task_handler import FileTaskHandler
@@ -95,19 +89,22 @@ def _initialize_method_map() -> dict[str, Callable]:
_get_ti_db_access,
_get_task_map_length,
_update_rtif,
- _update_ti_heartbeat,
_orig_start_date,
_handle_failure,
_handle_reschedule,
_add_log,
_xcom_pull,
_record_task_map_for_downstreams,
+ trigger_dag,
+ DagCode.remove_deleted_code,
DagModel.deactivate_deleted_dags,
DagModel.get_paused_dag_ids,
DagModel.get_current,
DagFileProcessor._execute_task_callbacks,
DagFileProcessor.execute_callbacks,
DagFileProcessor.execute_callbacks_without_dag,
+ # Airflow 2.10 compatibility
+ DagFileProcessor.manage_slas, # type: ignore[attr-defined]
DagFileProcessor.save_dag_to_db,
DagFileProcessor.update_import_errors,
DagFileProcessor._validate_task_pools_and_update_dag_warnings,
@@ -116,13 +113,18 @@ def _initialize_method_map() -> dict[str, Callable]:
DagFileProcessorManager.clear_nonexistent_import_errors,
DagFileProcessorManager.deactivate_stale_dags,
DagWarning.purge_inactive_dag_warnings,
- expand_alias_to_assets,
+ expand_alias_to_datasets,
+ DatasetManager.register_dataset_change,
FileTaskHandler._render_filename_db_access,
Job._add_to_db,
+ Job._fetch_from_db,
Job._kill,
Job._update_heartbeat,
Job._update_in_db,
most_recent_job,
+ # Airflow 2.10 compatibility
+ MetastoreBackend._fetch_connection, # type: ignore[attr-defined]
+ MetastoreBackend._fetch_variable, # type: ignore[attr-defined]
XCom.get_value,
XCom.get_one,
# XCom.get_many, # Not supported because it returns query
@@ -141,6 +143,7 @@ def _initialize_method_map() -> dict[str, Callable]:
DagRun._get_log_template,
RenderedTaskInstanceFields._update_runtime_evaluated_template_fields,
SerializedDagModel.get_serialized_dag,
+ SerializedDagModel.remove_deleted_dags,
SkipMixin._skip,
SkipMixin._skip_all_except,
TaskInstance._check_and_change_state_before_execution,
@@ -149,7 +152,6 @@ def _initialize_method_map() -> dict[str, Callable]:
TaskInstance._set_state,
TaskInstance.save_to_db,
TaskInstance._clear_xcom_data,
- TaskInstance._register_asset_changes_int,
Trigger.from_object,
Trigger.bulk_fetch,
Trigger.clean_unused,
@@ -158,6 +160,7 @@ def _initialize_method_map() -> dict[str, Callable]:
Trigger.ids_for_triggerer,
Trigger.assign_unassigned,
# Additional things from EdgeExecutor
+ # These are removed in follow-up PRs as being in transition to FastAPI
EdgeJob.reserve_task,
EdgeJob.set_state,
EdgeLogs.push_logs,
@@ -167,17 +170,6 @@ def _initialize_method_map() -> dict[str, Callable]:
return {f"{func.__module__}.{func.__qualname__}": func for func in
functions}
-@cache
-def _jwt_signer() -> JWTSigner:
- clock_grace = conf.getint("core", "internal_api_clock_grace", fallback=30)
- return JWTSigner(
- secret_key=conf.get("core", "internal_api_secret_key"),
- expiration_time_in_seconds=clock_grace,
- leeway_in_seconds=clock_grace,
- audience="api",
- )
-
-
def error_response(message: str, status: int):
"""Log the error and return the response as JSON object."""
error_id = uuid4()
@@ -187,124 +179,92 @@ def error_response(message: str, status: int):
return HTTPException(status, client_message)
-def json_request_headers(content_type: str = Header(), accept: str = Header()):
- """Check if the request headers are correct."""
- if content_type != "application/json":
- raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected Content-Type:
application/json")
- if accept != "application/json":
- raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected Accept:
application/json")
-
+def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
+ """Handle Edge Worker API `/edge_worker/v1/rpcapi` endpoint for Airflow
2.10."""
+ # Note: Except the method map this _was_ a 100% copy of internal API module
+ #
airflow.api_internal.endpoints.rpc_api_endpoint.internal_airflow_api()
+ # As of rework for FastAPI in Airflow 3.0, this is updated and to be
removed in the future.
+ from flask import Response, request
-def jwt_token_authorization(body: JsonRpcRequest, authorization: str =
Header()):
- """Check if the JWT token is correct."""
try:
- payload = _jwt_signer().verify_token(authorization)
- signed_method = payload.get("method")
- if not signed_method or signed_method != body.method:
- raise BadSignature("Invalid method in token authorization.")
- except BadSignature:
- raise HTTPException(
- status.HTTP_403_FORBIDDEN, "Bad Signature. Please use only the
tokens provided by the API."
- )
- except InvalidAudienceError:
- raise HTTPException(status.HTTP_403_FORBIDDEN, "Invalid audience for
the request")
- except InvalidSignatureError:
- raise HTTPException(status.HTTP_403_FORBIDDEN, "The signature of the
request was wrong")
- except ImmatureSignatureError:
- raise HTTPException(
- status.HTTP_403_FORBIDDEN, "The signature of the request was sent
from the future"
- )
- except ExpiredSignatureError:
- raise HTTPException(
- status.HTTP_403_FORBIDDEN,
- "The signature of the request has expired. Make sure that all
components "
- "in your system have synchronized clocks.",
- )
- except InvalidIssuedAtError:
- raise HTTPException(
- status.HTTP_403_FORBIDDEN,
- "The request was issues in the future. Make sure that all
components "
- "in your system have synchronized clocks.",
- )
- except Exception:
- raise HTTPException(status.HTTP_403_FORBIDDEN, "Unable to authenticate
API via token.")
+ if request.headers.get("Content-Type", "") != "application/json":
+ raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected
Content-Type: application/json")
+ if request.headers.get("Accept", "") != "application/json":
+ raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected Accept:
application/json")
+ auth = request.headers.get("Authorization", "")
+ request_obj = JsonRpcRequest(method=body["method"],
jsonrpc=body["jsonrpc"], params=body["params"])
+ jwt_token_authorization_rpc(request_obj, auth)
+ if request_obj.jsonrpc != "2.0":
+ raise error_response("Expected jsonrpc 2.0 request.",
status.HTTP_400_BAD_REQUEST)
+ log.debug("Got request for %s", request_obj.method)
+ methods_map = _initialize_method_map()
+ if request_obj.method not in methods_map:
+ raise error_response(f"Unrecognized method:
{request_obj.method}.", status.HTTP_400_BAD_REQUEST)
-def json_rpc_version(body: JsonRpcRequest):
- """Check if the JSON RPC Request version is correct."""
- if body.jsonrpc != "2.0":
- raise error_response("Expected jsonrpc 2.0 request.",
status.HTTP_400_BAD_REQUEST)
+ handler = methods_map[request_obj.method]
+ params = {}
+ try:
+ if request_obj.params:
+ params = BaseSerialization.deserialize(request_obj.params,
use_pydantic_models=True)
+ except Exception:
+ raise error_response("Error deserializing parameters.",
status.HTTP_400_BAD_REQUEST)
+ log.debug("Calling method %s\nparams: %s", request_obj.method, params)
+ try:
+ # Session must be created there as it may be needed by serializer
for lazy-loaded fields.
+ with create_session() as session:
+ output = handler(**params, session=session)
+ output_json = BaseSerialization.serialize(output,
use_pydantic_models=True)
+ log.debug(
+ "Sending response: %s", json.dumps(output_json) if
output_json is not None else None
+ )
+ # In case of AirflowException or other selective known types,
transport the exception class back to caller
+ except (KeyError, AttributeError, AirflowException) as e:
+ output_json = BaseSerialization.serialize(e,
use_pydantic_models=True)
+ log.debug(
+ "Sending exception response: %s", json.dumps(output_json) if
output_json is not None else None
+ )
+ except Exception:
+ raise error_response(
+ f"Error executing method '{request_obj.method}'.",
status.HTTP_500_INTERNAL_SERVER_ERROR
+ )
+ response = json.dumps(output_json) if output_json is not None else None
+ return Response(response=response, headers={"Content-Type":
"application/json"})
+ except HTTPException as e:
+ return e.to_response() # type: ignore[attr-defined]
-@rpc_api_router.post(
- "/rpcapi",
- dependencies=[Depends(json_request_headers),
Depends(jwt_token_authorization), Depends(json_rpc_version)],
- responses=create_openapi_http_exception_doc(
- [
- status.HTTP_400_BAD_REQUEST,
- status.HTTP_403_FORBIDDEN,
- status.HTTP_500_INTERNAL_SERVER_ERROR,
- ]
- ),
-)
-def rpcapi(body: JsonRpcRequest) -> Any | None:
- """Handle Edge Worker API calls as JSON-RPC."""
- log.debug("Got request for %s", body.method)
- methods_map = _initialize_method_map()
- if body.method not in methods_map:
- raise error_response(f"Unrecognized method: {body.method}.",
status.HTTP_400_BAD_REQUEST)
- handler = methods_map[body.method]
- params = {}
- try:
- if body.params:
- params = BaseSerialization.deserialize(body.params,
use_pydantic_models=True)
- except Exception:
- raise error_response("Error deserializing parameters.",
status.HTTP_400_BAD_REQUEST)
+@provide_session
+def register_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION)
-> Any:
+ """Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint
for Airflow 2.10."""
+ from flask import request
- log.debug("Calling method %s\nparams: %s", body.method, params)
try:
- # Session must be created there as it may be needed by serializer for
lazy-loaded fields.
- with create_session() as session:
- output = handler(**params, session=session)
- output_json = BaseSerialization.serialize(output,
use_pydantic_models=True)
- log.debug("Sending response: %s", json.dumps(output_json) if
output_json is not None else None)
- return output_json
- # In case of AirflowException or other selective known types, transport
the exception class back to caller
- except (KeyError, AttributeError, AirflowException) as e:
- exception_json = BaseSerialization.serialize(e,
use_pydantic_models=True)
- log.debug(
- "Sending exception response: %s", json.dumps(output_json) if
output_json is not None else None
- )
- return exception_json
- except Exception:
- raise error_response(
- f"Error executing method '{body.method}'.",
status.HTTP_500_INTERNAL_SERVER_ERROR
+ auth = request.headers.get("Authorization", "")
+ jwt_token_authorization(request.path, auth)
+ request_obj = WorkerStateBody(
+ state=body["state"], jobs_active=0, queues=body["queues"],
sysinfo=body["sysinfo"]
)
+ return register(worker_name, request_obj, session)
+ except HTTPException as e:
+ return e.to_response() # type: ignore[attr-defined]
-def edge_worker_api_v2(body: dict[str, Any]) -> APIResponse:
- """Handle Edge Worker API `/edge_worker/v1/rpcapi` endpoint for Airflow
2.10."""
- # Note: Except the method map this _was_ a 100% copy of internal API module
- #
airflow.api_internal.endpoints.rpc_api_endpoint.internal_airflow_api()
- # As of rework for FastAPI in Airflow 3.0, this is updated and to be
removed in future.
- from flask import Response, request
+@provide_session
+def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION)
-> Any:
+ """Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint
for Airflow 2.10."""
+ from flask import request
try:
- json_request_headers(
- content_type=request.headers.get("Content-Type", ""),
accept=request.headers.get("Accept", "")
- )
-
auth = request.headers.get("Authorization", "")
- json_rpc = body.get("jsonrpc", "")
- method_name = body.get("method", "")
- request_obj = JsonRpcRequest(method=method_name, jsonrpc=json_rpc,
params=body.get("params"))
- jwt_token_authorization(request_obj, auth)
-
- json_rpc_version(request_obj)
-
- output_json = rpcapi(request_obj)
- response = json.dumps(output_json) if output_json is not None else None
- return Response(response=response, headers={"Content-Type":
"application/json"})
+ jwt_token_authorization(request.path, auth)
+ request_obj = WorkerStateBody(
+ state=body["state"],
+ jobs_active=body["jobs_active"],
+ queues=body["queues"],
+ sysinfo=body["sysinfo"],
+ )
+ return set_state(worker_name, request_obj, session)
except HTTPException as e:
return e.to_response() # type: ignore[attr-defined]
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/worker.py
b/providers/src/airflow/providers/edge/worker_api/routes/worker.py
new file mode 100644
index 00000000000..369ace0d2df
--- /dev/null
+++ b/providers/src/airflow/providers/edge/worker_api/routes/worker.py
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+from datetime import datetime
+from typing import Annotated
+
+from sqlalchemy import select
+
+from airflow.providers.edge.models.edge_worker import EdgeWorkerModel,
set_metrics
+from airflow.providers.edge.worker_api.auth import jwt_token_authorization_rest
+from airflow.providers.edge.worker_api.datamodels import (
+ WorkerQueueUpdateBody, # noqa: TC001
+ WorkerStateBody, # noqa: TC001
+)
+from airflow.providers.edge.worker_api.routes._v2_compat import (
+ AirflowRouter,
+ Body,
+ Depends,
+ HTTPException,
+ Path,
+ SessionDep,
+ create_openapi_http_exception_doc,
+ status,
+)
+from airflow.stats import Stats
+from airflow.utils import timezone
+
+worker_router = AirflowRouter(
+ tags=["Worker"],
+ prefix="/worker",
+ responses=create_openapi_http_exception_doc(
+ [
+ status.HTTP_400_BAD_REQUEST,
+ status.HTTP_403_FORBIDDEN,
+ ]
+ ),
+)
+
+
+def _assert_version(sysinfo: dict[str, str | int]) -> None:
+ """Check if the Edge Worker version matches the central API site."""
+ from airflow import __version__ as airflow_version
+ from airflow.providers.edge import __version__ as edge_provider_version
+
+    # Note: In future, more stable versions we might be more liberal, for the
+ # moment we require exact version match for Edge Worker and core
version
+ if "airflow_version" in sysinfo:
+ airflow_on_worker = sysinfo["airflow_version"]
+ if airflow_on_worker != airflow_version:
+ raise HTTPException(
+ status.HTTP_400_BAD_REQUEST,
+ f"Edge Worker runs on Airflow {airflow_on_worker} "
+ f"and the core runs on {airflow_version}. Rejecting access due
to difference.",
+ )
+ else:
+ raise HTTPException(
+ status.HTTP_400_BAD_REQUEST, "Edge Worker does not specify the
version it is running on."
+ )
+
+ if "edge_provider_version" in sysinfo:
+ provider_on_worker = sysinfo["edge_provider_version"]
+ if provider_on_worker != edge_provider_version:
+ raise HTTPException(
+ status.HTTP_400_BAD_REQUEST,
+ f"Edge Worker runs on Edge Provider {provider_on_worker} "
+ f"and the core runs on {edge_provider_version}. Rejecting
access due to difference.",
+ )
+ else:
+ raise HTTPException(
+ status.HTTP_400_BAD_REQUEST, "Edge Worker does not specify the
provider version it is running on."
+ )
+
+
+_worker_name_doc = Path(title="Worker Name", description="Hostname or instance
name of the worker")
+_worker_state_doc = Body(
+ title="Worker State",
+ description="State of the worker with details",
+ examples=[
+ {
+ "state": "running",
+ "jobs_active": 3,
+ "queues": ["large_node", "wisconsin_site"],
+ "sysinfo": {
+ "concurrency": 4,
+ "airflow_version": "2.10.0",
+ "edge_provider_version": "1.0.0",
+ },
+ }
+ ],
+)
+_worker_queue_doc = Body(
+ title="Changes in worker queues",
+ description="Changes to be applied to current queues of worker",
+ examples=[{"new_queues": ["new_queue"], "remove_queues": ["old_queue"]}],
+)
+
+
+@worker_router.post("/{worker_name}",
dependencies=[Depends(jwt_token_authorization_rest)])
+def register(
+ worker_name: Annotated[str, _worker_name_doc],
+ body: Annotated[WorkerStateBody, _worker_state_doc],
+ session: SessionDep,
+) -> datetime:
+ """Register a new worker to the backend."""
+ _assert_version(body.sysinfo)
+ query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ worker: EdgeWorkerModel = session.scalar(query)
+ if not worker:
+ worker = EdgeWorkerModel(worker_name=worker_name, state=body.state,
queues=body.queues)
+ worker.state = body.state
+ worker.queues = body.queues
+ worker.sysinfo = json.dumps(body.sysinfo)
+ worker.last_update = timezone.utcnow()
+ session.add(worker)
+ return worker.last_update
+
+
+@worker_router.patch("/{worker_name}",
dependencies=[Depends(jwt_token_authorization_rest)])
+def set_state(
+ worker_name: Annotated[str, _worker_name_doc],
+ body: Annotated[WorkerStateBody, _worker_state_doc],
+ session: SessionDep,
+) -> list[str] | None:
+ """Set state of worker and returns the current assigned queues."""
+ query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ worker: EdgeWorkerModel = session.scalar(query)
+ worker.state = body.state
+ worker.jobs_active = body.jobs_active
+ worker.sysinfo = json.dumps(body.sysinfo)
+ worker.last_update = timezone.utcnow()
+ session.commit()
+ Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1)
+ Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name":
worker_name})
+ set_metrics(
+ worker_name=worker_name,
+ state=body.state,
+ jobs_active=body.jobs_active,
+ concurrency=int(body.sysinfo.get("concurrency", -1)),
+ free_concurrency=int(body.sysinfo["free_concurrency"]),
+ queues=worker.queues,
+ )
+ _assert_version(body.sysinfo) # Exception only after worker state is in
the DB
+ return worker.queues
+
+
+@worker_router.patch(
+ "/queues/{worker_name}",
+ dependencies=[Depends(jwt_token_authorization_rest)],
+)
+def update_queues(
+ worker_name: Annotated[str, _worker_name_doc],
+ body: Annotated[WorkerQueueUpdateBody, _worker_queue_doc],
+ session: SessionDep,
+) -> None:
+ query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ worker: EdgeWorkerModel = session.scalar(query)
+ if body.new_queues:
+ worker.add_queues(body.new_queues)
+ if body.remove_queues:
+ worker.remove_queues(body.remove_queues)
+ session.add(worker)
diff --git a/providers/tests/edge/cli/test_edge_command.py
b/providers/tests/edge/cli/test_edge_command.py
index f6612b1a99a..3304831064a 100644
--- a/providers/tests/edge/cli/test_edge_command.py
+++ b/providers/tests/edge/cli/test_edge_command.py
@@ -29,7 +29,7 @@ import time_machine
from airflow.exceptions import AirflowException
from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _Job,
_write_pid_to_pidfile
from airflow.providers.edge.models.edge_job import EdgeJob
-from airflow.providers.edge.models.edge_worker import EdgeWorker,
EdgeWorkerState, EdgeWorkerVersionException
+from airflow.providers.edge.models.edge_worker import EdgeWorkerState,
EdgeWorkerVersionException
from airflow.utils.state import TaskInstanceState
from tests_common.test_utils.config import conf_vars
@@ -74,8 +74,16 @@ def
test_write_pid_to_pidfile_created_by_crashed_instance(tmp_path):
assert str(os.getpid()) == pid_file_path.read_text().strip()
-# Ignore the following error for mocking
-# mypy: disable-error-code="attr-defined"
+class _MockPopen(Popen):
+ def __init__(self, returncode=None):
+ self.generated_returncode = None
+
+ def poll(self):
+ pass
+
+ @property
+ def returncode(self):
+ return self.generated_returncode
class TestEdgeWorkerCli:
@@ -84,19 +92,6 @@ class TestEdgeWorkerCli:
logfile = tmp_path / "file.log"
logfile.touch()
- class MockPopen(Popen):
- generated_returncode = None
-
- def __init__(self):
- pass
-
- def poll(self):
- pass
-
- @property
- def returncode(self):
- return self.generated_returncode
-
return [
_Job(
edge_job=EdgeJob(
@@ -113,7 +108,7 @@ class TestEdgeWorkerCli:
edge_worker=None,
last_update=None,
),
- process=MockPopen(),
+ process=_MockPopen(),
logfile=logfile,
logsize=0,
),
@@ -168,7 +163,7 @@ class TestEdgeWorkerCli:
logfile_path_call_count, set_state_call_count = expected_calls
mock_reserve_task.side_effect = [reserve_result]
mock_popen.side_effect = ["dummy"]
- with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ with conf_vars({("edge", "api_url"):
"https://invalid-api-test-endpoint"}):
got_job = worker_with_job.fetch_job()
mock_reserve_task.assert_called_once()
assert got_job == fetch_result
@@ -176,9 +171,8 @@ class TestEdgeWorkerCli:
assert mock_set_state.call_count == set_state_call_count
def test_check_running_jobs_running(self, worker_with_job: _EdgeWorkerCli):
- worker_with_job.jobs[0].process.generated_returncode = None
assert worker_with_job.free_concurrency == worker_with_job.concurrency
- with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ with conf_vars({("edge", "api_url"):
"https://invalid-api-test-endpoint"}):
worker_with_job.check_running_jobs()
assert len(worker_with_job.jobs) == 1
assert (
@@ -189,8 +183,8 @@ class TestEdgeWorkerCli:
@patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
def test_check_running_jobs_success(self, mock_set_state, worker_with_job:
_EdgeWorkerCli):
job = worker_with_job.jobs[0]
- job.process.generated_returncode = 0
- with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ job.process.generated_returncode = 0 # type: ignore[attr-defined]
+ with conf_vars({("edge", "api_url"):
"https://invalid-api-test-endpoint"}):
worker_with_job.check_running_jobs()
assert len(worker_with_job.jobs) == 0
mock_set_state.assert_called_once_with(job.edge_job.key,
TaskInstanceState.SUCCESS)
@@ -199,8 +193,8 @@ class TestEdgeWorkerCli:
@patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
def test_check_running_jobs_failed(self, mock_set_state, worker_with_job:
_EdgeWorkerCli):
job = worker_with_job.jobs[0]
- job.process.generated_returncode = 42
- with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ job.process.generated_returncode = 42 # type: ignore[attr-defined]
+ with conf_vars({("edge", "api_url"):
"https://invalid-api-test-endpoint"}):
worker_with_job.check_running_jobs()
assert len(worker_with_job.jobs) == 0
mock_set_state.assert_called_once_with(job.edge_job.key,
TaskInstanceState.FAILED)
@@ -210,10 +204,12 @@ class TestEdgeWorkerCli:
@patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
def test_check_running_jobs_log_push(self, mock_push_logs,
worker_with_job: _EdgeWorkerCli):
job = worker_with_job.jobs[0]
- job.process.generated_returncode = None
job.logfile.write_text("some log content")
with conf_vars(
- {("edge", "api_url"): "https://mock.server", ("edge",
"push_log_chunk_size"): "524288"}
+ {
+ ("edge", "api_url"): "https://invalid-api-test-endpoint",
+ ("edge", "push_log_chunk_size"): "524288",
+ }
):
worker_with_job.check_running_jobs()
assert len(worker_with_job.jobs) == 1
@@ -225,12 +221,14 @@ class TestEdgeWorkerCli:
@patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
def test_check_running_jobs_log_push_increment(self, mock_push_logs,
worker_with_job: _EdgeWorkerCli):
job = worker_with_job.jobs[0]
- job.process.generated_returncode = None
job.logfile.write_text("hello ")
job.logsize = job.logfile.stat().st_size
job.logfile.write_text("hello world")
with conf_vars(
- {("edge", "api_url"): "https://mock.server", ("edge",
"push_log_chunk_size"): "524288"}
+ {
+ ("edge", "api_url"): "https://invalid-api-test-endpoint",
+ ("edge", "push_log_chunk_size"): "524288",
+ }
):
worker_with_job.check_running_jobs()
assert len(worker_with_job.jobs) == 1
@@ -242,9 +240,10 @@ class TestEdgeWorkerCli:
@patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
def test_check_running_jobs_log_push_chunks(self, mock_push_logs,
worker_with_job: _EdgeWorkerCli):
job = worker_with_job.jobs[0]
- job.process.generated_returncode = None
job.logfile.write_bytes("log1log2ülog3".encode("latin-1"))
- with conf_vars({("edge", "api_url"): "https://mock.server", ("edge",
"push_log_chunk_size"): "4"}):
+ with conf_vars(
+ {("edge", "api_url"): "https://invalid-api-test-endpoint",
("edge", "push_log_chunk_size"): "4"}
+ ):
worker_with_job.check_running_jobs()
assert len(worker_with_job.jobs) == 1
calls = mock_push_logs.call_args_list
@@ -262,13 +261,13 @@ class TestEdgeWorkerCli:
pytest.param(False, False, EdgeWorkerState.IDLE, id="idle"),
],
)
- @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state")
+ @patch("airflow.providers.edge.cli.edge_command.worker_set_state")
def test_heartbeat(self, mock_set_state, drain, jobs, expected_state,
worker_with_job: _EdgeWorkerCli):
if not jobs:
worker_with_job.jobs = []
_EdgeWorkerCli.drain = drain
mock_set_state.return_value = ["queue1", "queue2"]
- with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ with conf_vars({("edge", "api_url"):
"https://invalid-api-test-endpoint"}):
worker_with_job.heartbeat()
assert mock_set_state.call_args.args[1] == expected_state
queue_list = worker_with_job.queues or []
@@ -276,13 +275,13 @@ class TestEdgeWorkerCli:
assert "queue1" in (queue_list)
assert "queue2" in (queue_list)
- @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state")
+ @patch("airflow.providers.edge.cli.edge_command.worker_set_state")
def test_version_mismatch(self, mock_set_state, worker_with_job):
mock_set_state.side_effect = EdgeWorkerVersionException("")
worker_with_job.heartbeat()
assert worker_with_job.drain
-
@patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
+ @patch("airflow.providers.edge.cli.edge_command.worker_register")
def test_start_missing_apiserver(self, mock_register_worker,
worker_with_job: _EdgeWorkerCli):
mock_register_worker.side_effect = AirflowException(
"Something with 404:NOT FOUND means API is not active"
@@ -290,42 +289,28 @@ class TestEdgeWorkerCli:
with pytest.raises(SystemExit, match=r"API endpoint is not ready"):
worker_with_job.start()
-
@patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
+ @patch("airflow.providers.edge.cli.edge_command.worker_register")
def test_start_server_error(self, mock_register_worker, worker_with_job:
_EdgeWorkerCli):
mock_register_worker.side_effect = AirflowException("Something other
error not FourhundretFour")
with pytest.raises(SystemExit, match=r"Something other"):
worker_with_job.start()
-
@patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
+ @patch("airflow.providers.edge.cli.edge_command.worker_register")
@patch("airflow.providers.edge.cli.edge_command._EdgeWorkerCli.loop")
- @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state")
+ @patch("airflow.providers.edge.cli.edge_command.worker_set_state")
def test_start_and_run_one(
- self, mock_set_state, mock_loop, mock_register_worker,
worker_with_job: _EdgeWorkerCli
+ self, mock_set_state, mock_loop, mock_register, worker_with_job:
_EdgeWorkerCli
):
- mock_register_worker.side_effect = [
- EdgeWorker(
- worker_name="test",
- state=EdgeWorkerState.STARTING,
- queues=None,
- first_online=datetime.now(),
- last_update=datetime.now(),
- jobs_active=0,
- jobs_taken=0,
- jobs_success=0,
- jobs_failed=0,
- sysinfo="",
- )
- ]
-
def stop_running():
_EdgeWorkerCli.drain = True
worker_with_job.jobs = []
mock_loop.side_effect = stop_running
+ mock_register.side_effect = [datetime.now()]
worker_with_job.start()
- mock_register_worker.assert_called_once()
+ mock_register.assert_called_once()
mock_loop.assert_called_once()
mock_set_state.assert_called_once()
diff --git a/providers/tests/edge/models/test_edge_worker.py
b/providers/tests/edge/worker_api/routes/test_worker.py
similarity index 67%
rename from providers/tests/edge/models/test_edge_worker.py
rename to providers/tests/edge/worker_api/routes/test_worker.py
index 20e394ffd57..e05a94c5f87 100644
--- a/providers/tests/edge/models/test_edge_worker.py
+++ b/providers/tests/edge/worker_api/routes/test_worker.py
@@ -22,11 +22,14 @@ from typing import TYPE_CHECKING
import pytest
from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli
-from airflow.providers.edge.models.edge_worker import (
- EdgeWorker,
- EdgeWorkerModel,
- EdgeWorkerState,
- EdgeWorkerVersionException,
+from airflow.providers.edge.models.edge_worker import EdgeWorkerModel,
EdgeWorkerState
+from airflow.providers.edge.worker_api.datamodels import
WorkerQueueUpdateBody, WorkerStateBody
+from airflow.providers.edge.worker_api.routes._v2_compat import HTTPException
+from airflow.providers.edge.worker_api.routes.worker import (
+ _assert_version,
+ register,
+ set_state,
+ update_queues,
)
from airflow.utils import timezone
@@ -36,7 +39,7 @@ if TYPE_CHECKING:
pytestmark = pytest.mark.db_test
-class TestEdgeWorker:
+class TestWorkerApiRoutes:
@pytest.fixture
def cli_worker(self, tmp_path: Path) -> _EdgeWorkerCli:
test_worker = _EdgeWorkerCli(str(tmp_path / "dummy.pid"), "dummy",
None, 8, 5, 5)
@@ -50,28 +53,22 @@ class TestEdgeWorker:
from airflow import __version__ as airflow_version
from airflow.providers.edge import __version__ as edge_provider_version
- with pytest.raises(EdgeWorkerVersionException):
- EdgeWorker.assert_version({})
+ with pytest.raises(HTTPException):
+ _assert_version({})
- with pytest.raises(EdgeWorkerVersionException):
- EdgeWorker.assert_version({"airflow_version": airflow_version})
+ with pytest.raises(HTTPException):
+ _assert_version({"airflow_version": airflow_version})
- with pytest.raises(EdgeWorkerVersionException):
- EdgeWorker.assert_version({"edge_provider_version":
edge_provider_version})
+ with pytest.raises(HTTPException):
+ _assert_version({"edge_provider_version": edge_provider_version})
- with pytest.raises(EdgeWorkerVersionException):
- EdgeWorker.assert_version(
- {"airflow_version": "1.2.3", "edge_provider_version":
edge_provider_version}
- )
+ with pytest.raises(HTTPException):
+ _assert_version({"airflow_version": "1.2.3",
"edge_provider_version": edge_provider_version})
- with pytest.raises(EdgeWorkerVersionException):
- EdgeWorker.assert_version(
- {"airflow_version": airflow_version, "edge_provider_version":
"2023.10.07"}
- )
+ with pytest.raises(HTTPException):
+ _assert_version({"airflow_version": airflow_version,
"edge_provider_version": "2023.10.07"})
- EdgeWorker.assert_version(
- {"airflow_version": airflow_version, "edge_provider_version":
edge_provider_version}
- )
+ _assert_version({"airflow_version": airflow_version,
"edge_provider_version": edge_provider_version})
@pytest.mark.parametrize(
"input_queues",
@@ -80,12 +77,15 @@ class TestEdgeWorker:
pytest.param(["default", "default2"], id="with-queues"),
],
)
- def test_register_worker(
- self, session: Session, input_queues: list[str] | None, cli_worker:
_EdgeWorkerCli
- ):
- EdgeWorker.register_worker(
- "test_worker", EdgeWorkerState.STARTING, queues=input_queues,
sysinfo=cli_worker._get_sysinfo()
+ def test_register(self, session: Session, input_queues: list[str] | None,
cli_worker: _EdgeWorkerCli):
+ body = WorkerStateBody(
+ state=EdgeWorkerState.STARTING,
+ jobs_active=0,
+ queues=input_queues,
+ sysinfo=cli_worker._get_sysinfo(),
)
+ register("test_worker", body, session)
+ session.commit()
worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
assert len(worker) == 1
@@ -106,9 +106,13 @@ class TestEdgeWorker:
session.add(rwm)
session.commit()
- return_queues = EdgeWorker.set_state(
- "test2_worker", EdgeWorkerState.RUNNING, 1,
cli_worker._get_sysinfo()
+ body = WorkerStateBody(
+ state=EdgeWorkerState.RUNNING,
+ jobs_active=1,
+ queues=["default2"],
+ sysinfo=cli_worker._get_sysinfo(),
)
+ return_queues = set_state("test2_worker", body, session)
worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
assert len(worker) == 1
@@ -127,13 +131,12 @@ class TestEdgeWorker:
pytest.param(["init"], None, ["init"], id="check-duplicated"),
],
)
- def test_add_and_remove_queues(
+ def test_update_queues(
self,
session: Session,
add_queues: list[str] | None,
remove_queues: list[str] | None,
expected_queues: list[str],
- cli_worker: _EdgeWorkerCli,
):
rwm = EdgeWorkerModel(
worker_name="test2_worker",
@@ -143,7 +146,8 @@ class TestEdgeWorker:
)
session.add(rwm)
session.commit()
- EdgeWorker.add_and_remove_queues("test2_worker", add_queues,
remove_queues, session)
+ body = WorkerQueueUpdateBody(new_queues=add_queues,
remove_queues=remove_queues)
+ update_queues("test2_worker", body, session)
worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
assert len(worker) == 1
assert worker[0].worker_name == "test2_worker"