This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3390bfbf98 AIP-69: Add CLI to Edge Provider (#42050)
3390bfbf98 is described below
commit 3390bfbf98c4ea4324ebfc16bd04e84e66daf73f
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Sep 24 22:49:36 2024 +0200
AIP-69: Add CLI to Edge Provider (#42050)
* Add CLI to Edge Provider
* Review feedback
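
For reference: based on the command and argument definitions added in edge_command.py below, the new CLI should be invocable roughly as follows (queue names, hostname, and concurrency values are illustrative):

    # Start a worker serving two queues with up to 4 parallel jobs
    airflow edge worker --queues remote1,remote2 --concurrency 4 --edge-hostname edge-node-1

    # Send SIGINT to a running worker so it drains and stops gracefully
    airflow edge stop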
---
airflow/providers/edge/cli/__init__.py | 16 ++
airflow/providers/edge/cli/edge_command.py | 313 ++++++++++++++++++++++++
tests/providers/edge/cli/__init__.py | 17 ++
tests/providers/edge/cli/test_edge_command.py | 259 ++++++++++++++++++++
tests/providers/edge/models/test_edge_worker.py | 29 +++
5 files changed, 634 insertions(+)
diff --git a/airflow/providers/edge/cli/__init__.py b/airflow/providers/edge/cli/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/edge/cli/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/edge/cli/edge_command.py b/airflow/providers/edge/cli/edge_command.py
new file mode 100644
index 0000000000..09998ffe80
--- /dev/null
+++ b/airflow/providers/edge/cli/edge_command.py
@@ -0,0 +1,313 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+import platform
+import signal
+import sys
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+from subprocess import Popen
+from time import sleep
+
+import psutil
+from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile
+
+from airflow import __version__ as airflow_version, settings
+from airflow.api_internal.internal_api_call import InternalApiConfig
+from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.edge import __version__ as edge_provider_version
+from airflow.providers.edge.models.edge_job import EdgeJob
+from airflow.providers.edge.models.edge_logs import EdgeLogs
+from airflow.providers.edge.models.edge_worker import EdgeWorker, EdgeWorkerState
+from airflow.utils import cli as cli_utils
+from airflow.utils.platform import IS_WINDOWS
+from airflow.utils.providers_configuration_loader import providers_configuration_loaded
+from airflow.utils.state import TaskInstanceState
+
+logger = logging.getLogger(__name__)
+EDGE_WORKER_PROCESS_NAME = "edge-worker"
+EDGE_WORKER_HEADER = "\n".join(
+ [
+ r" ____ __ _ __ __",
+ r" / __/__/ /__ ____ | | /| / /__ ____/ /_____ ____",
+ r" / _// _ / _ `/ -_) | |/ |/ / _ \/ __/ '_/ -_) __/",
+ r"/___/\_,_/\_, /\__/ |__/|__/\___/_/ /_/\_\\__/_/",
+ r" /___/",
+ r"",
+ ]
+)
+
+
+@providers_configuration_loaded
+def force_use_internal_api_on_edge_worker():
+ """
+ Ensure that the environment is configured for the internal API without needing to declare it outside.
+
+ This is only required for an Edge worker and must be done before the Click CLI wrapper is initiated.
+ That is because the CLI wrapper will attempt to establish a DB connection, which will fail before the
+ function call can take effect. In an Edge worker, we need to "patch" the environment before starting.
+ """
+ if "airflow" in sys.argv[0] and sys.argv[1:3] == ["edge", "worker"]:
+ api_url = conf.get("edge", "api_url")
+ if not api_url:
+ raise SystemExit("Error: API URL is not configured, please correct
configuration.")
+ logger.info("Starting worker with API endpoint %s", api_url)
+ # export Edge API to be used for internal API
+ os.environ["AIRFLOW_ENABLE_AIP_44"] = "True"
+ os.environ["AIRFLOW__CORE__INTERNAL_API_URL"] = api_url
+ InternalApiConfig.set_use_internal_api("edge-worker")
+ # Disable the mini-scheduler after task execution; leave scheduling of the next task to the core scheduler
+ os.environ["AIRFLOW__SCHEDULER__SCHEDULE_AFTER_TASK_EXECUTION"] = "False"
+
+
+force_use_internal_api_on_edge_worker()
+
+
+def _hostname() -> str:
+ if IS_WINDOWS:
+ return platform.uname().node
+ else:
+ return os.uname()[1]
+
+
+def _get_sysinfo() -> dict:
+ """Produce the sysinfo from worker to post to central site."""
+ return {
+ "airflow_version": airflow_version,
+ "edge_provider_version": edge_provider_version,
+ }
+
+
+def _pid_file_path(pid_file: str | None) -> str:
+ return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[0]
+
+
+@dataclass
+class _Job:
+ """Holds all information for a task/job to be executed as bundle."""
+
+ edge_job: EdgeJob
+ process: Popen
+ logfile: Path
+ logsize: int
+ """Last size of log file, point of last chunk push."""
+
+
+class _EdgeWorkerCli:
+ """Runner instance which executes the Edge Worker."""
+
+ jobs: list[_Job] = []
+ """List of jobs that the worker is running currently."""
+ last_hb: datetime | None = None
+ """Timestamp of last heart beat sent to server."""
+ drain: bool = False
+ """Flag if job processing should be completed and no new jobs fetched for
a graceful stop/shutdown."""
+
+ def __init__(
+ self,
+ pid_file_path: Path,
+ hostname: str,
+ queues: list[str] | None,
+ concurrency: int,
+ job_poll_interval: int,
+ heartbeat_interval: int,
+ ):
+ self.pid_file_path = pid_file_path
+ self.job_poll_interval = job_poll_interval
+ self.hb_interval = heartbeat_interval
+ self.hostname = hostname
+ self.queues = queues
+ self.concurrency = concurrency
+
+ @staticmethod
+ def signal_handler(sig, frame):
+ logger.info("Request to show down Edge Worker received, waiting for
jobs to complete.")
+ _EdgeWorkerCli.drain = True
+
+ def start(self):
+ """Start the execution in a loop until terminated."""
+ try:
+ self.last_hb = EdgeWorker.register_worker(
+ self.hostname, EdgeWorkerState.STARTING, self.queues, _get_sysinfo()
+ ).last_update
+ except AirflowException as e:
+ if "404:NOT FOUND" in str(e):
+ raise SystemExit("Error: API endpoint is not ready, please set
[edge] api_enabled=True.")
+ raise SystemExit(str(e))
+ write_pid_to_pidfile(self.pid_file_path)
+ signal.signal(signal.SIGINT, _EdgeWorkerCli.signal_handler)
+ try:
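+ # Run until a drain has been requested and all running jobs have completed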
+ while not _EdgeWorkerCli.drain or self.jobs:
+ self.loop()
+
+ logger.info("Quitting worker, signal being offline.")
+ EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0,
_get_sysinfo())
+ finally:
+ remove_existing_pidfile(self.pid_file_path)
+
+ def loop(self):
+ """Run a loop of scheduling and monitoring tasks."""
+ new_job = False
+ if not _EdgeWorkerCli.drain and len(self.jobs) < self.concurrency:
+ new_job = self.fetch_job()
+ self.check_running_jobs()
+
+ if _EdgeWorkerCli.drain or datetime.now().timestamp() - self.last_hb.timestamp() > self.hb_interval:
+ self.heartbeat()
+ self.last_hb = datetime.now()
+
+ if not new_job:
+ self.interruptible_sleep()
+
+ def fetch_job(self) -> bool:
+ """Fetch and start a new job from central site."""
+ logger.debug("Attempting to fetch a new job...")
+ edge_job = EdgeJob.reserve_task(self.hostname, self.queues)
+ if edge_job:
+ logger.info("Received job: %s", edge_job)
+ env = os.environ.copy()
+ env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True"
+ env["AIRFLOW__CORE__INTERNAL_API_URL"] = conf.get("edge",
"api_url")
+ env["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
+ process = Popen(edge_job.command, close_fds=True, env=env)
+ logfile = EdgeLogs.logfile_path(edge_job.key)
+ self.jobs.append(_Job(edge_job, process, logfile, 0))
+ EdgeJob.set_state(edge_job.key, TaskInstanceState.RUNNING)
+ return True
+
+ logger.info("No new job to process%s", f", {len(self.jobs)} still
running" if self.jobs else "")
+ return False
+
+ def check_running_jobs(self) -> None:
+ """Check which of the running tasks/jobs are completed and report
back."""
+ for i in range(len(self.jobs) - 1, -1, -1):
+ job = self.jobs[i]
+ job.process.poll()
+ if job.process.returncode is not None:
+ self.jobs.remove(job)
+ if job.process.returncode == 0:
+ logger.info("Job completed: %s", job.edge_job)
+ EdgeJob.set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
+ else:
+ logger.error("Job failed: %s", job.edge_job)
+ EdgeJob.set_state(job.edge_job.key, TaskInstanceState.FAILED)
+ if job.logfile.exists() and job.logfile.stat().st_size > job.logsize:
+ with job.logfile.open("r") as logfile:
+ logfile.seek(job.logsize, os.SEEK_SET)
+ logdata = logfile.read()
+ EdgeLogs.push_logs(
+ task=job.edge_job.key,
+ log_chunk_time=datetime.now(),
+ log_chunk_data=logdata,
+ )
+ job.logsize += len(logdata)
+
+ def heartbeat(self) -> None:
+ """Report liveness state of worker to central site with stats."""
+ state = (
+ (EdgeWorkerState.TERMINATING if _EdgeWorkerCli.drain else EdgeWorkerState.RUNNING)
+ if self.jobs
+ else EdgeWorkerState.IDLE
+ )
+ sysinfo = _get_sysinfo()
+ EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo)
+
+ def interruptible_sleep(self):
+ """Sleeps but stops sleeping if drain is made."""
+ drain_before_sleep = _EdgeWorkerCli.drain
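+ # Sleep in 0.1s ticks so that a drain request interrupts the wait promptly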
+ for _ in range(0, self.job_poll_interval * 10):
+ sleep(0.1)
+ if drain_before_sleep != _EdgeWorkerCli.drain:
+ return
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def worker(args):
+ """Start Airflow Edge Worker."""
+ print(settings.HEADER)
+ print(EDGE_WORKER_HEADER)
+
+ edge_worker = _EdgeWorkerCli(
+ pid_file_path=_pid_file_path(args.pid),
+ hostname=args.edge_hostname or _hostname(),
+ queues=args.queues.split(",") if args.queues else None,
+ concurrency=args.concurrency,
+ job_poll_interval=conf.getint("edge", "job_poll_interval"),
+ heartbeat_interval=conf.getint("edge", "heartbeat_interval"),
+ )
+ edge_worker.start()
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def stop(args):
+ """Stop a running Airflow Edge Worker."""
+ pid = read_pid_from_pidfile(_pid_file_path(args.pid))
+ # Send SIGINT
+ if pid:
+ logger.warning("Sending SIGINT to worker pid %i.", pid)
+ worker_process = psutil.Process(pid)
+ worker_process.send_signal(signal.SIGINT)
+ else:
+ logger.warning("Could not find PID of worker.")
+
+
+ARG_CONCURRENCY = Arg(
+ ("-c", "--concurrency"),
+ type=int,
+ help="The number of worker processes",
+ default=conf.getint("edge", "worker_concurrency", fallback=8),
+)
+ARG_QUEUES = Arg(
+ ("-q", "--queues"),
+ help="Comma delimited list of queues to serve, serve all queues if not
provided.",
+)
+ARG_EDGE_HOSTNAME = Arg(
+ ("-H", "--edge-hostname"),
+ help="Set the hostname of worker if you have multiple workers on a single
machine",
+)
+EDGE_COMMANDS: list[ActionCommand] = [
+ ActionCommand(
+ name=worker.__name__,
+ help=worker.__doc__,
+ func=worker,
+ args=(
+ ARG_CONCURRENCY,
+ ARG_QUEUES,
+ ARG_EDGE_HOSTNAME,
+ ARG_PID,
+ ARG_VERBOSE,
+ ),
+ ),
+ ActionCommand(
+ name=stop.__name__,
+ help=stop.__doc__,
+ func=stop,
+ args=(
+ ARG_PID,
+ ARG_VERBOSE,
+ ),
+ ),
+]
diff --git a/tests/providers/edge/cli/__init__.py b/tests/providers/edge/cli/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/tests/providers/edge/cli/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/edge/cli/test_edge_command.py b/tests/providers/edge/cli/test_edge_command.py
new file mode 100644
index 0000000000..398c221db0
--- /dev/null
+++ b/tests/providers/edge/cli/test_edge_command.py
@@ -0,0 +1,259 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+from pathlib import Path
+from subprocess import Popen
+from unittest.mock import patch
+
+import pytest
+import time_machine
+
+from airflow.exceptions import AirflowException
+from airflow.providers.edge.cli.edge_command import (
+ _EdgeWorkerCli,
+ _get_sysinfo,
+ _Job,
+)
+from airflow.providers.edge.models.edge_job import EdgeJob
+from airflow.providers.edge.models.edge_worker import EdgeWorker, EdgeWorkerState
+from airflow.utils.state import TaskInstanceState
+from tests.test_utils.config import conf_vars
+
+pytest.importorskip("pydantic", minversion="2.0.0")
+
+# Ignore the following error for mocking
+# mypy: disable-error-code="attr-defined"
+
+
+def test_get_sysinfo():
+ sysinfo = _get_sysinfo()
+ assert "airflow_version" in sysinfo
+ assert "edge_provider_version" in sysinfo
+
+
+class TestEdgeWorkerCli:
+ @pytest.fixture
+ def dummy_joblist(self, tmp_path: Path) -> list[_Job]:
+ logfile = tmp_path / "file.log"
+ logfile.touch()
+
+ class MockPopen(Popen):
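+ """Popen test double whose returncode is controlled via generated_returncode."""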
+ generated_returncode = None
+
+ def __init__(self):
+ pass
+
+ def poll(self):
+ pass
+
+ @property
+ def returncode(self):
+ return self.generated_returncode
+
+ return [
+ _Job(
+ edge_job=EdgeJob(
+ dag_id="test",
+ task_id="test1",
+ run_id="test",
+ map_index=-1,
+ try_number=1,
+ state=TaskInstanceState.RUNNING,
+ queue="test",
+ command=["test", "command"],
+ queued_dttm=datetime.now(),
+ edge_worker=None,
+ last_update=None,
+ ),
+ process=MockPopen(),
+ logfile=logfile,
+ logsize=0,
+ ),
+ ]
+
+ @pytest.fixture
+ def worker_with_job(self, tmp_path: Path, dummy_joblist: list[_Job]) -> _EdgeWorkerCli:
+ test_worker = _EdgeWorkerCli(tmp_path / "dummy.pid", "dummy", None, 8, 5, 5)
+ test_worker.jobs = dummy_joblist
+ return test_worker
+
+ @pytest.mark.parametrize(
+ "reserve_result, fetch_result, expected_calls",
+ [
+ pytest.param(None, False, (0, 0), id="no_job"),
+ pytest.param(
+ EdgeJob(
+ dag_id="test",
+ task_id="test",
+ run_id="test",
+ map_index=-1,
+ try_number=1,
+ state=TaskInstanceState.QUEUED,
+ queue="test",
+ command=["test", "command"],
+ queued_dttm=datetime.now(),
+ edge_worker=None,
+ last_update=None,
+ ),
+ True,
+ (1, 1),
+ id="new_job",
+ ),
+ ],
+ )
+ @patch("airflow.providers.edge.models.edge_job.EdgeJob.reserve_task")
+ @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.logfile_path")
+ @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+ @patch("subprocess.Popen")
+ def test_fetch_job(
+ self,
+ mock_popen,
+ mock_set_state,
+ mock_logfile_path,
+ mock_reserve_task,
+ reserve_result,
+ fetch_result,
+ expected_calls,
+ worker_with_job: _EdgeWorkerCli,
+ ):
+ logfile_path_call_count, set_state_call_count = expected_calls
+ mock_reserve_task.side_effect = [reserve_result]
+ mock_popen.side_effect = ["dummy"]
+ with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ got_job = worker_with_job.fetch_job()
+ mock_reserve_task.assert_called_once()
+ assert got_job == fetch_result
+ assert mock_logfile_path.call_count == logfile_path_call_count
+ assert mock_set_state.call_count == set_state_call_count
+
+ def test_check_running_jobs_running(self, worker_with_job: _EdgeWorkerCli):
+ worker_with_job.jobs[0].process.generated_returncode = None
+ with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ worker_with_job.check_running_jobs()
+ assert len(worker_with_job.jobs) == 1
+
+ @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+ def test_check_running_jobs_success(self, mock_set_state, worker_with_job: _EdgeWorkerCli):
+ job = worker_with_job.jobs[0]
+ job.process.generated_returncode = 0
+ with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ worker_with_job.check_running_jobs()
+ assert len(worker_with_job.jobs) == 0
+ mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.SUCCESS)
+
+ @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
+ def test_check_running_jobs_failed(self, mock_set_state, worker_with_job: _EdgeWorkerCli):
+ job = worker_with_job.jobs[0]
+ job.process.generated_returncode = 42
+ with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ worker_with_job.check_running_jobs()
+ assert len(worker_with_job.jobs) == 0
+ mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.FAILED)
+
+ @time_machine.travel(datetime.now(), tick=False)
+ @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
+ def test_check_running_jobs_log_push(self, mock_push_logs, worker_with_job: _EdgeWorkerCli):
+ job = worker_with_job.jobs[0]
+ job.process.generated_returncode = None
+ job.logfile.write_text("some log content")
+ with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ worker_with_job.check_running_jobs()
+ assert len(worker_with_job.jobs) == 1
+ mock_push_logs.assert_called_once_with(
+ task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="some log content"
+ )
+
+ @time_machine.travel(datetime.now(), tick=False)
+ @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
+ def test_check_running_jobs_log_push_increment(self, mock_push_logs, worker_with_job: _EdgeWorkerCli):
+ job = worker_with_job.jobs[0]
+ job.process.generated_returncode = None
+ job.logfile.write_text("hello ")
+ job.logsize = job.logfile.stat().st_size
+ job.logfile.write_text("hello world")
+ with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ worker_with_job.check_running_jobs()
+ assert len(worker_with_job.jobs) == 1
+ mock_push_logs.assert_called_once_with(
+ task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="world"
+ )
+
+ @pytest.mark.parametrize(
+ "drain, jobs, expected_state",
+ [
+ pytest.param(False, True, EdgeWorkerState.RUNNING, id="running_jobs"),
+ pytest.param(True, True, EdgeWorkerState.TERMINATING, id="shutting_down"),
+ pytest.param(False, False, EdgeWorkerState.IDLE, id="idle"),
+ ],
+ )
+ @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state")
+ def test_heartbeat(self, mock_set_state, drain, jobs, expected_state, worker_with_job: _EdgeWorkerCli):
+ if not jobs:
+ worker_with_job.jobs = []
+ _EdgeWorkerCli.drain = drain
+ with conf_vars({("edge", "api_url"): "https://mock.server"}):
+ worker_with_job.heartbeat()
+ assert mock_set_state.call_args.args[1] == expected_state
+
+ @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
+ def test_start_missing_apiserver(self, mock_register_worker, worker_with_job: _EdgeWorkerCli):
+ mock_register_worker.side_effect = AirflowException(
+ "Something with 404:NOT FOUND means API is not active"
+ )
+ with pytest.raises(SystemExit, match=r"API endpoint is not ready"):
+ worker_with_job.start()
+
+ @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
+ def test_start_server_error(self, mock_register_worker, worker_with_job: _EdgeWorkerCli):
+ mock_register_worker.side_effect = AirflowException("Something other error, not a FourHundredFour")
+ with pytest.raises(SystemExit, match=r"Something other"):
+ worker_with_job.start()
+
+ @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
+ @patch("airflow.providers.edge.cli.edge_command._EdgeWorkerCli.loop")
+ @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state")
+ def test_start_and_run_one(
+ self, mock_set_state, mock_loop, mock_register_worker, worker_with_job: _EdgeWorkerCli
+ ):
+ mock_register_worker.side_effect = [
+ EdgeWorker(
+ worker_name="test",
+ state=EdgeWorkerState.STARTING,
+ queues=None,
+ first_online=datetime.now(),
+ last_update=datetime.now(),
+ jobs_active=0,
+ jobs_taken=0,
+ jobs_success=0,
+ jobs_failed=0,
+ sysinfo="",
+ )
+ ]
+
+ def stop_running():
+ _EdgeWorkerCli.drain = True
+ worker_with_job.jobs = []
+
+ mock_loop.side_effect = stop_running
+
+ worker_with_job.start()
+
+ mock_register_worker.assert_called_once()
+ mock_loop.assert_called_once()
+ mock_set_state.assert_called_once()
diff --git a/tests/providers/edge/models/test_edge_worker.py b/tests/providers/edge/models/test_edge_worker.py
index 9eca293baf..f0e0ac9dfa 100644
--- a/tests/providers/edge/models/test_edge_worker.py
+++ b/tests/providers/edge/models/test_edge_worker.py
@@ -20,11 +20,14 @@ from typing import TYPE_CHECKING
import pytest
+from airflow.providers.edge.cli.edge_command import _get_sysinfo
from airflow.providers.edge.models.edge_worker import (
EdgeWorker,
EdgeWorkerModel,
+ EdgeWorkerState,
EdgeWorkerVersionException,
)
+from airflow.utils import timezone
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -63,3 +66,29 @@ class TestEdgeWorker:
EdgeWorker.assert_version(
{"airflow_version": airflow_version, "edge_provider_version":
edge_provider_version}
)
+
+ def test_register_worker(self, session: Session):
+ EdgeWorker.register_worker(
+ "test_worker", EdgeWorkerState.STARTING, queues=None,
sysinfo=_get_sysinfo()
+ )
+
+ worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
+ assert len(worker) == 1
+ assert worker[0].worker_name == "test_worker"
+
+ def test_set_state(self, session: Session):
+ rwm = EdgeWorkerModel(
+ worker_name="test2_worker",
+ state=EdgeWorkerState.IDLE,
+ queues=["default"],
+ first_online=timezone.utcnow(),
+ )
+ session.add(rwm)
+ session.commit()
+
+ EdgeWorker.set_state("test2_worker", EdgeWorkerState.RUNNING, 1,
_get_sysinfo())
+
+ worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
+ assert len(worker) == 1
+ assert worker[0].worker_name == "test2_worker"
+ assert worker[0].state == EdgeWorkerState.RUNNING
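
For reference: the [edge] configuration keys read by this CLI are api_url, api_enabled (implied by the error message in start()), job_poll_interval, heartbeat_interval, and worker_concurrency. A minimal airflow.cfg sketch with purely illustrative values and URL:

    [edge]
    api_enabled = True
    api_url = https://airflow.example.com/edge-api
    job_poll_interval = 5
    heartbeat_interval = 10
    worker_concurrency = 8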