jscheffl commented on code in PR #44311:
URL: https://github.com/apache/airflow/pull/44311#discussion_r1863907148


##########
providers/src/airflow/providers/edge/worker_api/routes/worker.py:
##########
@@ -0,0 +1,195 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+from datetime import datetime
+from typing import Annotated
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session  # noqa: TCH002
+
+from airflow.providers.edge.models.edge_worker import EdgeWorkerModel, 
set_metrics
+from airflow.providers.edge.worker_api.auth import jwt_token_authorization_rest
+from airflow.providers.edge.worker_api.datamodels import (
+    WorkerQueueUpdateBody,  # noqa: TC001
+    WorkerStateBody,  # noqa: TC001
+)
+from airflow.providers.edge.worker_api.routes._v2_compat import (
+    AirflowRouter,
+    Body,
+    Depends,
+    HTTPException,
+    Path,
+    create_openapi_http_exception_doc,
+    get_session,
+    status,
+)
+from airflow.stats import Stats
+from airflow.utils import timezone
+
# Router for the Edge Worker lifecycle endpoints: every route registered on it
# is served below the "/worker" prefix and tagged "Worker".
worker_router = AirflowRouter(prefix="/worker", tags=["Worker"])
+
+
def _assert_version(sysinfo: dict[str, str | int]) -> None:
    """
    Reject an Edge Worker whose reported versions differ from the API server's.

    Both ``airflow_version`` and ``edge_provider_version`` must be present in
    ``sysinfo`` and must equal the versions installed where this API runs.
    The Airflow version is checked before the provider version.

    :param sysinfo: System information reported by the worker.
    :raises HTTPException: HTTP 400 if either version is missing or differs.
    """
    from airflow import __version__ as airflow_version
    from airflow.providers.edge import __version__ as edge_provider_version

    # For now an exact version match is required between the Edge Worker and the
    # core; once the protocol is more stable this may become more liberal.
    if "airflow_version" not in sysinfo:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, "Edge Worker does not specify the version it is running on."
        )
    worker_airflow = sysinfo["airflow_version"]
    if worker_airflow != airflow_version:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            f"Edge Worker runs on Airflow {worker_airflow} "
            f"and the core runs on {airflow_version}. Rejecting access due to difference.",
        )

    if "edge_provider_version" not in sysinfo:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            "Edge Worker does not specify the provider version it is running on.",
        )
    worker_provider = sysinfo["edge_provider_version"]
    if worker_provider != edge_provider_version:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            f"Edge Worker runs on Edge Provider {worker_provider} "
            f"and the core runs on {edge_provider_version}. Rejecting access due to difference.",
        )
+
+
# Reusable OpenAPI parameter/body documentation shared by the worker endpoints.
_worker_name_doc = Path(title="Worker Name", description="Hostname or instance name of the worker")
_worker_state_doc = Body(
    title="Worker State",
    description="State of the worker with details",
    examples=[
        {
            "state": "running",
            "jobs_active": 3,
            "queues": ["large_node", "wisconsin_site"],
            "sysinfo": {
                "concurrency": 4,
                # The state-update endpoint reads sysinfo["free_concurrency"], so the
                # example includes it to be a usable payload for that endpoint too.
                "free_concurrency": 1,
                "airflow_version": "2.10.0",
                "edge_provider_version": "1.0.0",
            },
        }
    ],
)
_worker_queue_doc = Body(
    title="Changes in worker queues",
    description="Changes to be applied to current queues of worker",
    examples=[{"new_queues": ["new_queue"], "remove_queues": ["old_queue"]}],
)
+
+
@worker_router.post(
    "/{worker_name}",
    dependencies=[Depends(jwt_token_authorization_rest)],
    responses=create_openapi_http_exception_doc([status.HTTP_400_BAD_REQUEST, status.HTTP_403_FORBIDDEN]),
)
def register(
    worker_name: Annotated[str, _worker_name_doc],
    body: Annotated[WorkerStateBody, _worker_state_doc],
    session: Annotated[Session, Depends(get_session)],
) -> datetime:
    """
    Register a new worker to the backend, or refresh an existing registration.

    The reported versions are validated first. On success the worker's state,
    queues and JSON-encoded system information are stored together with a
    fresh update timestamp, and the row is added to the session.

    :param worker_name: Hostname or instance name of the worker.
    :param body: Reported worker state, queues and system information.
    :param session: Database session.
    :return: The timestamp stored as the worker's last update.
    :raises HTTPException: HTTP 400 if the worker's versions are missing or differ.
    """
    _assert_version(body.sysinfo)

    existing: EdgeWorkerModel | None = session.scalar(
        select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
    )
    if existing:
        worker = existing
    else:
        worker = EdgeWorkerModel(worker_name=worker_name, state=body.state, queues=body.queues)

    # Overwrite the fields unconditionally so a re-registering worker is refreshed too.
    worker.state = body.state
    worker.queues = body.queues
    worker.sysinfo = json.dumps(body.sysinfo)
    worker.last_update = timezone.utcnow()
    session.add(worker)
    return worker.last_update
+
+
@worker_router.patch(
    "/{worker_name}",
    dependencies=[Depends(jwt_token_authorization_rest)],
    responses=create_openapi_http_exception_doc(
        [
            status.HTTP_400_BAD_REQUEST,
            status.HTTP_403_FORBIDDEN,
        ]
    ),
)
def set_state(
    worker_name: Annotated[str, _worker_name_doc],
    body: Annotated[WorkerStateBody, _worker_state_doc],
    session: Annotated[Session, Depends(get_session)],
) -> list[str] | None:
    """
    Set state of worker and return the currently assigned queues.

    Stores the reported state, active job count, system information and a fresh
    update timestamp, commits, and emits heartbeat stats and worker metrics.
    The version check runs only after the commit, so the worker's state is
    recorded even when its versions are rejected.

    :param worker_name: Hostname or instance name of the worker.
    :param body: Reported worker state, active jobs and system information.
    :param session: Database session.
    :return: The queues currently assigned to the worker.
    :raises HTTPException: HTTP 400 if the worker is not registered, or (after the
        state was committed) if its reported versions are missing or differ.
    """
    query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
    worker: EdgeWorkerModel | None = session.scalar(query)
    if not worker:
        # A worker must register (POST) before sending state updates. Reject the
        # request explicitly rather than failing on attribute access of None.
        # NOTE(review): 400 is used because it is already a documented response of
        # this endpoint; 404 may be preferable if the compat `status` provides it.
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            f"Edge Worker {worker_name} is not registered. Register the worker before setting its state.",
        )
    worker.state = body.state
    worker.jobs_active = body.jobs_active
    worker.sysinfo = json.dumps(body.sysinfo)
    worker.last_update = timezone.utcnow()
    session.commit()
    Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1)
    Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name": worker_name})
    set_metrics(
        worker_name=worker_name,
        state=body.state,
        jobs_active=body.jobs_active,
        concurrency=int(body.sysinfo.get("concurrency", -1)),
        # Same fallback as `concurrency`: a missing value must not raise KeyError
        # here, after the commit, which would also skip the version check below.
        free_concurrency=int(body.sysinfo.get("free_concurrency", -1)),
        queues=worker.queues,
    )
    _assert_version(body.sysinfo)  # Exception only after worker state is in the DB
    return worker.queues

Review Comment:
   I do not understand what you mean. `_assert_version()` raises HTTP 400 if the
versions are missing or do not match; otherwise the return value is
`list[str] | None`, as declared in the signature. What do you think is missing?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to