This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3390bfbf98 AIP-69: Add CLI to Edge Provider (#42050)
3390bfbf98 is described below

commit 3390bfbf98c4ea4324ebfc16bd04e84e66daf73f
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Sep 24 22:49:36 2024 +0200

    AIP-69: Add CLI to Edge Provider (#42050)
    
    * Add CLI to Edge Provider
    
    * Review feedback
---
 airflow/providers/edge/cli/__init__.py          |  16 ++
 airflow/providers/edge/cli/edge_command.py      | 313 ++++++++++++++++++++++++
 tests/providers/edge/cli/__init__.py            |  17 ++
 tests/providers/edge/cli/test_edge_command.py   | 259 ++++++++++++++++++++
 tests/providers/edge/models/test_edge_worker.py |  29 +++
 5 files changed, 634 insertions(+)

diff --git a/airflow/providers/edge/cli/__init__.py b/airflow/providers/edge/cli/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/edge/cli/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/edge/cli/edge_command.py b/airflow/providers/edge/cli/edge_command.py
new file mode 100644
index 0000000000..09998ffe80
--- /dev/null
+++ b/airflow/providers/edge/cli/edge_command.py
@@ -0,0 +1,313 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+import platform
+import signal
+import sys
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+from subprocess import Popen
+from time import sleep
+
+import psutil
+from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile
+
+from airflow import __version__ as airflow_version, settings
+from airflow.api_internal.internal_api_call import InternalApiConfig
+from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.edge import __version__ as edge_provider_version
+from airflow.providers.edge.models.edge_job import EdgeJob
+from airflow.providers.edge.models.edge_logs import EdgeLogs
+from airflow.providers.edge.models.edge_worker import EdgeWorker, EdgeWorkerState
+from airflow.utils import cli as cli_utils
+from airflow.utils.platform import IS_WINDOWS
+from airflow.utils.providers_configuration_loader import providers_configuration_loaded
+from airflow.utils.state import TaskInstanceState
+
+logger = logging.getLogger(__name__)
+EDGE_WORKER_PROCESS_NAME = "edge-worker"
+EDGE_WORKER_HEADER = "\n".join(
+    [
+        r"   ____   __           _      __         __",
+        r"  / __/__/ /__ ____   | | /| / /__  ____/ /_____ ____",
+        r" / _// _  / _ `/ -_)  | |/ |/ / _ \/ __/  '_/ -_) __/",
+        r"/___/\_,_/\_, /\__/   |__/|__/\___/_/ /_/\_\\__/_/",
+        r"         /___/",
+        r"",
+    ]
+)
+
+
+@providers_configuration_loaded
+def force_use_internal_api_on_edge_worker():
+    """
+    Ensure that the environment is configured for the internal API without 
needing to declare it outside.
+
+    This is only required for an Edge worker and must to be done before the 
Click CLI wrapper is initiated.
+    That is because the CLI wrapper will attempt to establish a DB connection, 
which will fail before the
+    function call can take effect. In an Edge worker, we need to "patch" the 
environment before starting.
+    """
+    if "airflow" in sys.argv[0] and sys.argv[1:3] == ["edge", "worker"]:
+        api_url = conf.get("edge", "api_url")
+        if not api_url:
+            raise SystemExit("Error: API URL is not configured, please correct 
configuration.")
+        logger.info("Starting worker with API endpoint %s", api_url)
+        # export Edge API to be used for internal API
+        os.environ["AIRFLOW_ENABLE_AIP_44"] = "True"
+        os.environ["AIRFLOW__CORE__INTERNAL_API_URL"] = api_url
+        InternalApiConfig.set_use_internal_api("edge-worker")
+        # Disable the mini-scheduler after task execution; scheduling of the next task is left to the core scheduler
+        os.environ["AIRFLOW__SCHEDULER__SCHEDULE_AFTER_TASK_EXECUTION"] = "False"
+
+
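+# Run at import time so the environment is patched before the Click CLI wrapper initiates a DB connection.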
+force_use_internal_api_on_edge_worker()
+
+
+def _hostname() -> str:
+    if IS_WINDOWS:
+        return platform.uname().node
+    else:
+        return os.uname()[1]
+
+
+def _get_sysinfo() -> dict:
+    """Produce the sysinfo from worker to post to central site."""
+    return {
+        "airflow_version": airflow_version,
+        "edge_provider_version": edge_provider_version,
+    }
+
+
+def _pid_file_path(pid_file: str | None) -> str:
+    return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[0]
+
+
+@dataclass
+class _Job:
+    """Holds all information for a task/job to be executed as bundle."""
+
+    edge_job: EdgeJob
+    process: Popen
+    logfile: Path
+    logsize: int
+    """Last size of log file, point of last chunk push."""
+
+
+class _EdgeWorkerCli:
+    """Runner instance which executes the Edge Worker."""
+
+    jobs: list[_Job] = []
+    """List of jobs that the worker is running currently."""
+    last_hb: datetime | None = None
+    """Timestamp of last heart beat sent to server."""
+    drain: bool = False
+    """Flag if job processing should be completed and no new jobs fetched for 
a graceful stop/shutdown."""
+
+    def __init__(
+        self,
+        pid_file_path: Path,
+        hostname: str,
+        queues: list[str] | None,
+        concurrency: int,
+        job_poll_interval: int,
+        heartbeat_interval: int,
+    ):
+        self.pid_file_path = pid_file_path
+        self.job_poll_interval = job_poll_interval
+        self.hb_interval = heartbeat_interval
+        self.hostname = hostname
+        self.queues = queues
+        self.concurrency = concurrency
+
+    @staticmethod
+    def signal_handler(sig, frame):
+        logger.info("Request to show down Edge Worker received, waiting for 
jobs to complete.")
+        _EdgeWorkerCli.drain = True
+
+    def start(self):
+        """Start the execution in a loop until terminated."""
+        try:
+            self.last_hb = EdgeWorker.register_worker(
+                self.hostname, EdgeWorkerState.STARTING, self.queues, _get_sysinfo()
+            ).last_update
+        except AirflowException as e:
+            if "404:NOT FOUND" in str(e):
+                raise SystemExit("Error: API endpoint is not ready, please set 
[edge] api_enabled=True.")
+            raise SystemExit(str(e))
+        write_pid_to_pidfile(self.pid_file_path)
+        signal.signal(signal.SIGINT, _EdgeWorkerCli.signal_handler)
+        try:
+            while not _EdgeWorkerCli.drain or self.jobs:
+                self.loop()
+
+            logger.info("Quitting worker, signal being offline.")
+            EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, 
_get_sysinfo())
+        finally:
+            remove_existing_pidfile(self.pid_file_path)
+
+    def loop(self):
+        """Run a loop of scheduling and monitoring tasks."""
+        new_job = False
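+        # Fetch a new job only while not draining and below the concurrency limit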
+        if not _EdgeWorkerCli.drain and len(self.jobs) < self.concurrency:
+            new_job = self.fetch_job()
+        self.check_running_jobs()
+
+        if _EdgeWorkerCli.drain or datetime.now().timestamp() - self.last_hb.timestamp() > self.hb_interval:
+            self.heartbeat()
+            self.last_hb = datetime.now()
+
+        if not new_job:
+            self.interruptible_sleep()
+
+    def fetch_job(self) -> bool:
+        """Fetch and start a new job from central site."""
+        logger.debug("Attempting to fetch a new job...")
+        edge_job = EdgeJob.reserve_task(self.hostname, self.queues)
+        if edge_job:
+            logger.info("Received job: %s", edge_job)
+            env = os.environ.copy()
+            env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True"
+            env["AIRFLOW__CORE__INTERNAL_API_URL"] = conf.get("edge", 
"api_url")
+            env["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
+            process = Popen(edge_job.command, close_fds=True, env=env)
+            logfile = EdgeLogs.logfile_path(edge_job.key)
+            self.jobs.append(_Job(edge_job, process, logfile, 0))
+            EdgeJob.set_state(edge_job.key, TaskInstanceState.RUNNING)
+            return True
+
+        logger.info("No new job to process%s", f", {len(self.jobs)} still 
running" if self.jobs else "")
+        return False
+
+    def check_running_jobs(self) -> None:
+        """Check which of the running tasks/jobs are completed and report 
back."""
+        for i in range(len(self.jobs) - 1, -1, -1):
+            job = self.jobs[i]
+            job.process.poll()
+            if job.process.returncode is not None:
+                self.jobs.remove(job)
+                if job.process.returncode == 0:
+                    logger.info("Job completed: %s", job.edge_job)
+                    EdgeJob.set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
+                else:
+                    logger.error("Job failed: %s", job.edge_job)
+                    EdgeJob.set_state(job.edge_job.key, TaskInstanceState.FAILED)
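+            # Push any log content written since the last transmitted chunk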
+            if job.logfile.exists() and job.logfile.stat().st_size > job.logsize:
+                with job.logfile.open("r") as logfile:
+                    logfile.seek(job.logsize, os.SEEK_SET)
+                    logdata = logfile.read()
+                    EdgeLogs.push_logs(
+                        task=job.edge_job.key,
+                        log_chunk_time=datetime.now(),
+                        log_chunk_data=logdata,
+                    )
+                    job.logsize += len(logdata)
+
+    def heartbeat(self) -> None:
+        """Report liveness state of worker to central site with stats."""
+        state = (
+            (EdgeWorkerState.TERMINATING if _EdgeWorkerCli.drain else EdgeWorkerState.RUNNING)
+            if self.jobs
+            else EdgeWorkerState.IDLE
+        )
+        sysinfo = _get_sysinfo()
+        EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo)
+
+    def interruptible_sleep(self):
+        """Sleeps but stops sleeping if drain is made."""
+        drain_before_sleep = _EdgeWorkerCli.drain
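+        # Sleep in 0.1 s slices so a drain request cuts the wait short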
+        for _ in range(0, self.job_poll_interval * 10):
+            sleep(0.1)
+            if drain_before_sleep != _EdgeWorkerCli.drain:
+                return
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def worker(args):
+    """Start Airflow Edge Worker."""
+    print(settings.HEADER)
+    print(EDGE_WORKER_HEADER)
+
+    edge_worker = _EdgeWorkerCli(
+        pid_file_path=_pid_file_path(args.pid),
+        hostname=args.edge_hostname or _hostname(),
+        queues=args.queues.split(",") if args.queues else None,
+        concurrency=args.concurrency,
+        job_poll_interval=conf.getint("edge", "job_poll_interval"),
+        heartbeat_interval=conf.getint("edge", "heartbeat_interval"),
+    )
+    edge_worker.start()
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def stop(args):
+    """Stop a running Airflow Edge Worker."""
+    pid = read_pid_from_pidfile(_pid_file_path(args.pid))
+    # Send SIGINT
+    if pid:
+        logger.warning("Sending SIGINT to worker pid %i.", pid)
+        worker_process = psutil.Process(pid)
+        worker_process.send_signal(signal.SIGINT)
+    else:
+        logger.warning("Could not find PID of worker.")
+
+
+ARG_CONCURRENCY = Arg(
+    ("-c", "--concurrency"),
+    type=int,
+    help="The number of worker processes",
+    default=conf.getint("edge", "worker_concurrency", fallback=8),
+)
+ARG_QUEUES = Arg(
+    ("-q", "--queues"),
+    help="Comma delimited list of queues to serve, serve all queues if not 
provided.",
+)
+ARG_EDGE_HOSTNAME = Arg(
+    ("-H", "--edge-hostname"),
+    help="Set the hostname of worker if you have multiple workers on a single 
machine",
+)
+EDGE_COMMANDS: list[ActionCommand] = [
+    ActionCommand(
+        name=worker.__name__,
+        help=worker.__doc__,
+        func=worker,
+        args=(
+            ARG_CONCURRENCY,
+            ARG_QUEUES,
+            ARG_EDGE_HOSTNAME,
+            ARG_PID,
+            ARG_VERBOSE,
+        ),
+    ),
+    ActionCommand(
+        name=stop.__name__,
+        help=stop.__doc__,
+        func=stop,
+        args=(
+            ARG_PID,
+            ARG_VERBOSE,
+        ),
+    ),
+]
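
For illustration only (queue names and values below are made up, not part of this commit): once the provider is installed and [edge] api_url points at a reachable API endpoint, the new commands plug into the standard Airflow CLI, e.g.

    airflow edge worker --queues remote,sensors --concurrency 4
    airflow edge stop
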
diff --git a/tests/providers/edge/cli/__init__.py b/tests/providers/edge/cli/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/tests/providers/edge/cli/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/edge/cli/test_edge_command.py b/tests/providers/edge/cli/test_edge_command.py
new file mode 100644
index 0000000000..398c221db0
--- /dev/null
+++ b/tests/providers/edge/cli/test_edge_command.py
@@ -0,0 +1,259 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+from pathlib import Path
+from subprocess import Popen
+from unittest.mock import patch
+
+import pytest
+import time_machine
+
+from airflow.exceptions import AirflowException
+from airflow.providers.edge.cli.edge_command import (
+    _EdgeWorkerCli,
+    _get_sysinfo,
+    _Job,
+)
+from airflow.providers.edge.models.edge_job import EdgeJob
+from airflow.providers.edge.models.edge_worker import EdgeWorker, EdgeWorkerState
+from airflow.utils.state import TaskInstanceState
+from tests.test_utils.config import conf_vars
+
+pytest.importorskip("pydantic", minversion="2.0.0")
+
+# Ignore the following error for mocking
+# mypy: disable-error-code="attr-defined"
+
+
+def test_get_sysinfo():
+    sysinfo = _get_sysinfo()
+    assert "airflow_version" in sysinfo
+    assert "edge_provider_version" in sysinfo
+
+
+class TestEdgeWorkerCli:
+    @pytest.fixture
+    def dummy_joblist(self, tmp_path: Path) -> list[_Job]:
+        logfile = tmp_path / "file.log"
+        logfile.touch()
+
+        class MockPopen(Popen):
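+            # Stand-in for Popen whose return code is controlled per test via generated_returncode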
+            generated_returncode = None
+
+            def __init__(self):
+                pass
+
+            def poll(self):
+                pass
+
+            @property
+            def returncode(self):
+                return self.generated_returncode
+
+        return [
+            _Job(
+                edge_job=EdgeJob(
+                    dag_id="test",
+                    task_id="test1",
+                    run_id="test",
+                    map_index=-1,
+                    try_number=1,
+                    state=TaskInstanceState.RUNNING,
+                    queue="test",
+                    command=["test", "command"],
+                    queued_dttm=datetime.now(),
+                    edge_worker=None,
+                    last_update=None,
+                ),
+                process=MockPopen(),
+                logfile=logfile,
+                logsize=0,
+            ),
+        ]
+
+    @pytest.fixture
+    def worker_with_job(self, tmp_path: Path, dummy_joblist: list[_Job]) -> _EdgeWorkerCli:
+        test_worker = _EdgeWorkerCli(tmp_path / "dummy.pid", "dummy", None, 8, 5, 5)
+        test_worker.jobs = dummy_joblist
+        return test_worker
+
+    @pytest.mark.parametrize(
+        "reserve_result, fetch_result, expected_calls",
+        [
+            pytest.param(None, False, (0, 0), id="no_job"),
+            pytest.param(
+                EdgeJob(
+                    dag_id="test",
+                    task_id="test",
+                    run_id="test",
+                    map_index=-1,
+                    try_number=1,
+                    state=TaskInstanceState.QUEUED,
+                    queue="test",
+                    command=["test", "command"],
+                    queued_dttm=datetime.now(),
+                    edge_worker=None,
+                    last_update=None,
+                ),
+                True,
+                (1, 1),
+                id="new_job",
+            ),
+        ],
+    )
+    @patch("airflow.providers.edge.models.edge_job.EdgeJob.reserve_task")
+    @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.logfile_path")
+    @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+    @patch("subprocess.Popen")
+    def test_fetch_job(
+        self,
+        mock_popen,
+        mock_set_state,
+        mock_logfile_path,
+        mock_reserve_task,
+        reserve_result,
+        fetch_result,
+        expected_calls,
+        worker_with_job: _EdgeWorkerCli,
+    ):
+        logfile_path_call_count, set_state_call_count = expected_calls
+        mock_reserve_task.side_effect = [reserve_result]
+        mock_popen.side_effect = ["dummy"]
+        with conf_vars({("edge", "api_url"): "https://mock.server"}):
+            got_job = worker_with_job.fetch_job()
+        mock_reserve_task.assert_called_once()
+        assert got_job == fetch_result
+        assert mock_logfile_path.call_count == logfile_path_call_count
+        assert mock_set_state.call_count == set_state_call_count
+
+    def test_check_running_jobs_running(self, worker_with_job: _EdgeWorkerCli):
+        worker_with_job.jobs[0].process.generated_returncode = None
+        with conf_vars({("edge", "api_url"): "https://mock.server"}):
+            worker_with_job.check_running_jobs()
+        assert len(worker_with_job.jobs) == 1
+
+    @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+    def test_check_running_jobs_success(self, mock_set_state, worker_with_job: _EdgeWorkerCli):
+        job = worker_with_job.jobs[0]
+        job.process.generated_returncode = 0
+        with conf_vars({("edge", "api_url"): "https://mock.server"}):
+            worker_with_job.check_running_jobs()
+        assert len(worker_with_job.jobs) == 0
+        mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.SUCCESS)
+
+    @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+    def test_check_running_jobs_failed(self, mock_set_state, worker_with_job: _EdgeWorkerCli):
+        job = worker_with_job.jobs[0]
+        job.process.generated_returncode = 42
+        with conf_vars({("edge", "api_url"): "https://mock.server"}):
+            worker_with_job.check_running_jobs()
+        assert len(worker_with_job.jobs) == 0
+        mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.FAILED)
+
+    @time_machine.travel(datetime.now(), tick=False)
+    @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
+    def test_check_running_jobs_log_push(self, mock_push_logs, worker_with_job: _EdgeWorkerCli):
+        job = worker_with_job.jobs[0]
+        job.process.generated_returncode = None
+        job.logfile.write_text("some log content")
+        with conf_vars({("edge", "api_url"): "https://mock.server"}):
+            worker_with_job.check_running_jobs()
+        assert len(worker_with_job.jobs) == 1
+        mock_push_logs.assert_called_once_with(
+            task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="some log content"
+        )
+
+    @time_machine.travel(datetime.now(), tick=False)
+    @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
+    def test_check_running_jobs_log_push_increment(self, mock_push_logs, worker_with_job: _EdgeWorkerCli):
+        job = worker_with_job.jobs[0]
+        job.process.generated_returncode = None
+        job.logfile.write_text("hello ")
+        job.logsize = job.logfile.stat().st_size
+        job.logfile.write_text("hello world")
+        with conf_vars({("edge", "api_url"): "https://mock.server"}):
+            worker_with_job.check_running_jobs()
+        assert len(worker_with_job.jobs) == 1
+        mock_push_logs.assert_called_once_with(
+            task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="world"
+        )
+
+    @pytest.mark.parametrize(
+        "drain, jobs, expected_state",
+        [
+            pytest.param(False, True, EdgeWorkerState.RUNNING, id="running_jobs"),
+            pytest.param(True, True, EdgeWorkerState.TERMINATING, id="shutting_down"),
+            pytest.param(False, False, EdgeWorkerState.IDLE, id="idle"),
+        ],
+    )
+    @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state")
+    def test_heartbeat(self, mock_set_state, drain, jobs, expected_state, worker_with_job: _EdgeWorkerCli):
+        if not jobs:
+            worker_with_job.jobs = []
+        _EdgeWorkerCli.drain = drain
+        with conf_vars({("edge", "api_url"): "https://mock.server"}):
+            worker_with_job.heartbeat()
+        assert mock_set_state.call_args.args[1] == expected_state
+
+    @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
+    def test_start_missing_apiserver(self, mock_register_worker, worker_with_job: _EdgeWorkerCli):
+        mock_register_worker.side_effect = AirflowException(
+            "Something with 404:NOT FOUND means API is not active"
+        )
+        with pytest.raises(SystemExit, match=r"API endpoint is not ready"):
+            worker_with_job.start()
+
+    @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
+    def test_start_server_error(self, mock_register_worker, worker_with_job: _EdgeWorkerCli):
+        mock_register_worker.side_effect = AirflowException("Something other error, not FourHundredFour")
+        with pytest.raises(SystemExit, match=r"Something other"):
+            worker_with_job.start()
+
+    @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
+    @patch("airflow.providers.edge.cli.edge_command._EdgeWorkerCli.loop")
+    @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state")
+    def test_start_and_run_one(
+        self, mock_set_state, mock_loop, mock_register_worker, worker_with_job: _EdgeWorkerCli
+    ):
+        mock_register_worker.side_effect = [
+            EdgeWorker(
+                worker_name="test",
+                state=EdgeWorkerState.STARTING,
+                queues=None,
+                first_online=datetime.now(),
+                last_update=datetime.now(),
+                jobs_active=0,
+                jobs_taken=0,
+                jobs_success=0,
+                jobs_failed=0,
+                sysinfo="",
+            )
+        ]
+
+        def stop_running():
+            _EdgeWorkerCli.drain = True
+            worker_with_job.jobs = []
+
+        mock_loop.side_effect = stop_running
+
+        worker_with_job.start()
+
+        mock_register_worker.assert_called_once()
+        mock_loop.assert_called_once()
+        mock_set_state.assert_called_once()
diff --git a/tests/providers/edge/models/test_edge_worker.py b/tests/providers/edge/models/test_edge_worker.py
index 9eca293baf..f0e0ac9dfa 100644
--- a/tests/providers/edge/models/test_edge_worker.py
+++ b/tests/providers/edge/models/test_edge_worker.py
@@ -20,11 +20,14 @@ from typing import TYPE_CHECKING
 
 import pytest
 
+from airflow.providers.edge.cli.edge_command import _get_sysinfo
 from airflow.providers.edge.models.edge_worker import (
     EdgeWorker,
     EdgeWorkerModel,
+    EdgeWorkerState,
     EdgeWorkerVersionException,
 )
+from airflow.utils import timezone
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -63,3 +66,29 @@ class TestEdgeWorker:
         EdgeWorker.assert_version(
             {"airflow_version": airflow_version, "edge_provider_version": 
edge_provider_version}
         )
+
+    def test_register_worker(self, session: Session):
+        EdgeWorker.register_worker(
+            "test_worker", EdgeWorkerState.STARTING, queues=None, 
sysinfo=_get_sysinfo()
+        )
+
+        worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
+        assert len(worker) == 1
+        assert worker[0].worker_name == "test_worker"
+
+    def test_set_state(self, session: Session):
+        rwm = EdgeWorkerModel(
+            worker_name="test2_worker",
+            state=EdgeWorkerState.IDLE,
+            queues=["default"],
+            first_online=timezone.utcnow(),
+        )
+        session.add(rwm)
+        session.commit()
+
+        EdgeWorker.set_state("test2_worker", EdgeWorkerState.RUNNING, 1, 
_get_sysinfo())
+
+        worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
+        assert len(worker) == 1
+        assert worker[0].worker_name == "test2_worker"
+        assert worker[0].state == EdgeWorkerState.RUNNING
