This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 68adc0e059 Refactor commands to unify daemon context handling (#34945)
68adc0e059 is described below
commit 68adc0e059ac65f20dfc7cf0038edb96b1244d32
Author: Daniel Dyląg <[email protected]>
AuthorDate: Tue Oct 24 10:16:47 2023 +0200
Refactor commands to unify daemon context handling (#34945)
---
.pre-commit-config.yaml | 1 +
airflow/cli/commands/celery_command.py | 89 +++++++++----------------
airflow/cli/commands/daemon_utils.py | 82 +++++++++++++++++++++++
airflow/cli/commands/dag_processor_command.py | 32 ++-------
airflow/cli/commands/internal_api_command.py | 95 ++++++++++++---------------
airflow/cli/commands/kerberos_command.py | 29 ++------
airflow/cli/commands/scheduler_command.py | 50 +++++---------
airflow/cli/commands/triggerer_command.py | 48 ++++----------
airflow/cli/commands/webserver_command.py | 93 +++++++++++---------------
tests/cli/commands/test_celery_command.py | 18 ++---
tests/cli/commands/test_kerberos_command.py | 29 ++++----
11 files changed, 256 insertions(+), 310 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 85cfc050c1..d7c10d53ff 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -509,6 +509,7 @@ repos:
^airflow/api_connexion/openapi/v1.yaml$|
^airflow/auth/managers/fab/security_manager/|
^airflow/cli/commands/webserver_command.py$|
+ ^airflow/cli/commands/internal_api_command.py$|
^airflow/config_templates/|
^airflow/models/baseoperator.py$|
^airflow/operators/__init__.py$|
diff --git a/airflow/cli/commands/celery_command.py b/airflow/cli/commands/celery_command.py
index eb53d6f60d..5e3e01042a 100644
--- a/airflow/cli/commands/celery_command.py
+++ b/airflow/cli/commands/celery_command.py
@@ -23,19 +23,18 @@ import sys
from contextlib import contextmanager
from multiprocessing import Process
-import daemon
import psutil
import sqlalchemy.exc
from celery import maybe_patch_concurrency # type: ignore[attr-defined]
from celery.app.defaults import DEFAULT_TASK_LOG_FMT
from celery.signals import after_setup_logger
-from daemon.pidfile import TimeoutPIDLockFile
from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile
from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
from airflow.configuration import conf
from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging
+from airflow.utils.cli import setup_locations
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.serve_logs import serve_logs
@@ -68,28 +67,9 @@ def flower(args):
if args.flower_conf:
options.append(f"--conf={args.flower_conf}")
- if args.daemon:
- pidfile, stdout, stderr, _ = setup_locations(
- process="flower",
- pid=args.pid,
- stdout=args.stdout,
- stderr=args.stderr,
- log=args.log_file,
- )
- with open(stdout, "a") as stdout, open(stderr, "a") as stderr:
- stdout.truncate(0)
- stderr.truncate(0)
-
- ctx = daemon.DaemonContext(
- pidfile=TimeoutPIDLockFile(pidfile, -1),
- stdout=stdout,
- stderr=stderr,
- umask=int(settings.DAEMON_UMASK, 8),
- )
- with ctx:
- celery_app.start(options)
- else:
- celery_app.start(options)
+ run_command_with_daemon_option(
+ args=args, process_name="flower", callback=lambda:
celery_app.start(options)
+ )
@contextmanager
@@ -152,15 +132,6 @@ def worker(args):
if autoscale is None and conf.has_option("celery", "worker_autoscale"):
autoscale = conf.get("celery", "worker_autoscale")
- # Setup locations
- pid_file_path, stdout, stderr, log_file = setup_locations(
- process=WORKER_PROCESS_NAME,
- pid=args.pid,
- stdout=args.stdout,
- stderr=args.stderr,
- log=args.log_file,
- )
-
if hasattr(celery_app.backend, "ResultSession"):
# Pre-create the database tables now, otherwise SQLA via Celery has a
# race condition where one of the subprocesses can die with "Table
@@ -181,6 +152,10 @@ def worker(args):
celery_log_level = conf.get("logging", "CELERY_LOGGING_LEVEL")
if not celery_log_level:
celery_log_level = conf.get("logging", "LOGGING_LEVEL")
+
+ # Setup pid file location
+ worker_pid_file_path, _, _, _ = setup_locations(process=WORKER_PROCESS_NAME, pid=args.pid)
+
# Setup Celery worker
options = [
"worker",
@@ -195,7 +170,7 @@ def worker(args):
"--loglevel",
celery_log_level,
"--pidfile",
- pid_file_path,
+ worker_pid_file_path,
]
if autoscale:
options.extend(["--autoscale", autoscale])
@@ -214,33 +189,31 @@ def worker(args):
# executed.
maybe_patch_concurrency(["-P", pool])
- if args.daemon:
- # Run Celery worker as daemon
- handle = setup_logging(log_file)
-
- with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
- if args.umask:
- umask = args.umask
- else:
- umask = conf.get("celery", "worker_umask",
fallback=settings.DAEMON_UMASK)
-
- stdout_handle.truncate(0)
- stderr_handle.truncate(0)
-
- daemon_context = daemon.DaemonContext(
- files_preserve=[handle],
- umask=int(umask, 8),
- stdout=stdout_handle,
- stderr=stderr_handle,
- )
- with daemon_context, _serve_logs(skip_serve_logs):
- celery_app.worker_main(options)
+ _, stdout, stderr, log_file = setup_locations(
+ process=WORKER_PROCESS_NAME,
+ stdout=args.stdout,
+ stderr=args.stderr,
+ log=args.log_file,
+ )
- else:
- # Run Celery worker in the same process
+ def run_celery_worker():
with _serve_logs(skip_serve_logs):
celery_app.worker_main(options)
+ if args.umask:
+ umask = args.umask
+ else:
+ umask = conf.get("celery", "worker_umask",
fallback=settings.DAEMON_UMASK)
+
+ run_command_with_daemon_option(
+ args=args,
+ process_name=WORKER_PROCESS_NAME,
+ callback=run_celery_worker,
+ should_setup_logging=True,
+ umask=umask,
+ pid_file=worker_pid_file_path,
+ )
+
@cli_utils.action_cli
@providers_configuration_loaded
diff --git a/airflow/cli/commands/daemon_utils.py b/airflow/cli/commands/daemon_utils.py
new file mode 100644
index 0000000000..9184b1f7db
--- /dev/null
+++ b/airflow/cli/commands/daemon_utils.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import signal
+from argparse import Namespace
+from typing import Callable
+
+from daemon import daemon
+from daemon.pidfile import TimeoutPIDLockFile
+
+from airflow import settings
+from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler
+from airflow.utils.process_utils import check_if_pidfile_process_is_running
+
+
+def run_command_with_daemon_option(
+ *,
+ args: Namespace,
+ process_name: str,
+ callback: Callable,
+ should_setup_logging: bool = False,
+ umask: str = settings.DAEMON_UMASK,
+ pid_file: str | None = None,
+):
+ """Run the command in a daemon process if daemon mode enabled or within
this process if not.
+
+ :param args: the set of arguments passed to the original CLI command
+ :param process_name: process name used in naming log and PID files for the daemon
+ :param callback: the actual command to run with or without daemon context
+ :param should_setup_logging: if true, then a log file handler for the daemon process will be created
+ :param umask: file access creation mask ("umask") to set for the process on daemon start
+ :param pid_file: if specified, this file path is used to store the daemon process PID.
+ If not specified, a file path is generated with the default pattern.
+ """
+ if args.daemon:
+ pid, stdout, stderr, log_file = setup_locations(
+ process=process_name, stdout=args.stdout, stderr=args.stderr, log=args.log_file
+ )
+ if pid_file:
+ pid = pid_file
+
+ # Check if the process is already running; if not but a pidfile exists, clean it up
+ check_if_pidfile_process_is_running(pid_file=pid, process_name=process_name)
+
+ if should_setup_logging:
+ files_preserve = [setup_logging(log_file)]
+ else:
+ files_preserve = None
+ with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
+ stdout_handle.truncate(0)
+ stderr_handle.truncate(0)
+
+ ctx = daemon.DaemonContext(
+ pidfile=TimeoutPIDLockFile(pid, -1),
+ files_preserve=files_preserve,
+ stdout=stdout_handle,
+ stderr=stderr_handle,
+ umask=int(umask, 8),
+ )
+
+ with ctx:
+ callback()
+ else:
+ signal.signal(signal.SIGINT, sigint_handler)
+ signal.signal(signal.SIGTERM, sigint_handler)
+ signal.signal(signal.SIGQUIT, sigquit_handler)
+ callback()
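For reviewers, a minimal usage sketch of the new helper follows. The `my_service` command, its `serve_forever` callback, and the "my-service" process name are illustrative placeholders, not part of this commit:

    from argparse import Namespace

    from airflow.cli.commands.daemon_utils import run_command_with_daemon_option


    def my_service(args: Namespace):
        # Hypothetical CLI entrypoint. The callback runs inside a DaemonContext when
        # --daemon is passed; otherwise it runs in the current process, with the
        # SIGINT/SIGTERM/SIGQUIT handlers installed by the helper.
        def serve_forever():
            ...  # actual work of the command

        run_command_with_daemon_option(
            args=args,                  # expects args.daemon, args.stdout, args.stderr, args.log_file
            process_name="my-service",  # used to derive default PID/stdout/stderr/log file locations
            callback=serve_forever,
            should_setup_logging=True,  # keep the daemon's log file handler in files_preserve
        )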
diff --git a/airflow/cli/commands/dag_processor_command.py b/airflow/cli/commands/dag_processor_command.py
index cf880f6622..85bef2727d 100644
--- a/airflow/cli/commands/dag_processor_command.py
+++ b/airflow/cli/commands/dag_processor_command.py
@@ -21,16 +21,12 @@ import logging
from datetime import timedelta
from typing import Any
-import daemon
-from daemon.pidfile import TimeoutPIDLockFile
-
-from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
from airflow.configuration import conf
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner
from airflow.jobs.job import Job, run_job
from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
log = logging.getLogger(__name__)
@@ -66,23 +62,9 @@ def dag_processor(args):
job_runner = _create_dag_processor_job_runner(args)
- if args.daemon:
- pid, stdout, stderr, log_file = setup_locations(
- "dag-processor", args.pid, args.stdout, args.stderr, args.log_file
- )
- handle = setup_logging(log_file)
- with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
- stdout_handle.truncate(0)
- stderr_handle.truncate(0)
-
- ctx = daemon.DaemonContext(
- pidfile=TimeoutPIDLockFile(pid, -1),
- files_preserve=[handle],
- stdout=stdout_handle,
- stderr=stderr_handle,
- umask=int(settings.DAEMON_UMASK, 8),
- )
- with ctx:
- run_job(job=job_runner.job, execute_callable=job_runner._execute)
- else:
- run_job(job=job_runner.job, execute_callable=job_runner._execute)
+ run_command_with_daemon_option(
+ args=args,
+ process_name="dag-processor",
+ callback=lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute),
+ should_setup_logging=True,
+ )
diff --git a/airflow/cli/commands/internal_api_command.py b/airflow/cli/commands/internal_api_command.py
index 73ed2e2501..f558c89cab 100644
--- a/airflow/cli/commands/internal_api_command.py
+++ b/airflow/cli/commands/internal_api_command.py
@@ -28,9 +28,7 @@ from pathlib import Path
from tempfile import gettempdir
from time import sleep
-import daemon
import psutil
-from daemon.pidfile import TimeoutPIDLockFile
from flask import Flask
from flask_appbuilder import SQLA
from flask_caching import Cache
@@ -40,14 +38,14 @@ from sqlalchemy.engine.url import make_url
from airflow import settings
from airflow.api_internal.internal_api_call import InternalApiConfig
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
from airflow.cli.commands.webserver_command import GunicornMonitor
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
from airflow.logging_config import configure_logging
from airflow.models import import_all_models
from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging
-from airflow.utils.process_utils import check_if_pidfile_process_is_running
+from airflow.utils.cli import setup_locations
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.www.extensions.init_dagbag import init_dagbag
from airflow.www.extensions.init_jinja_globals import init_jinja_globals
@@ -81,13 +79,6 @@ def internal_api(args):
host=args.hostname,
)
else:
- pid_file, stdout, stderr, log_file = setup_locations(
- "internal-api", args.pid, args.stdout, args.stderr, args.log_file
- )
-
- # Check if Internal APi is already running if not, remove old pidfile
- check_if_pidfile_process_is_running(pid_file=pid_file, process_name="internal-api")
-
log.info(
textwrap.dedent(
f"""\
@@ -101,6 +92,8 @@ def internal_api(args):
)
)
+ pid_file, _, _, _ = setup_locations("internal-api", pid=args.pid)
+
run_args = [
sys.executable,
"-m",
@@ -137,25 +130,27 @@ def internal_api(args):
# then have a copy of the app
run_args += ["--preload"]
- gunicorn_master_proc: psutil.Process | None = None
-
- def kill_proc(signum, _):
+ def kill_proc(signum: int, gunicorn_master_proc: psutil.Process | subprocess.Popen):
log.info("Received signal: %s. Closing gunicorn.", signum)
gunicorn_master_proc.terminate()
with suppress(TimeoutError):
gunicorn_master_proc.wait(timeout=30)
- if gunicorn_master_proc.is_running():
+ if isinstance(gunicorn_master_proc, subprocess.Popen):
+ still_running = gunicorn_master_proc.poll() is not None
+ else:
+ still_running = gunicorn_master_proc.is_running()
+ if still_running:
gunicorn_master_proc.kill()
sys.exit(0)
- def monitor_gunicorn(gunicorn_master_pid: int):
+ def monitor_gunicorn(gunicorn_master_proc: psutil.Process | subprocess.Popen):
# Register signal handlers
- signal.signal(signal.SIGINT, kill_proc)
- signal.signal(signal.SIGTERM, kill_proc)
+ signal.signal(signal.SIGINT, lambda signum, _: kill_proc(signum, gunicorn_master_proc))
+ signal.signal(signal.SIGTERM, lambda signum, _: kill_proc(signum, gunicorn_master_proc))
# These run forever until SIG{INT, TERM, KILL, ...} signal is sent
GunicornMonitor(
- gunicorn_master_pid=gunicorn_master_pid,
+ gunicorn_master_pid=gunicorn_master_proc.pid,
num_workers_expected=num_workers,
master_timeout=120,
worker_refresh_interval=30,
@@ -163,45 +158,39 @@ def internal_api(args):
reload_on_plugin_change=False,
).start()
+ def start_and_monitor_gunicorn(args):
+ if args.daemon:
+ subprocess.Popen(run_args, close_fds=True)
+
+ # Reading pid of gunicorn master as it will be different from
+ # the one of the process spawned above.
+ gunicorn_master_proc_pid = None
+ while not gunicorn_master_proc_pid:
+ sleep(0.1)
+ gunicorn_master_proc_pid = read_pid_from_pidfile(pid_file)
+
+ # Run Gunicorn monitor
+ gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid)
+ monitor_gunicorn(gunicorn_master_proc)
+ else:
+ with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc:
+ monitor_gunicorn(gunicorn_master_proc)
+
if args.daemon:
# This makes possible errors get reported before daemonization
os.environ["SKIP_DAGS_PARSING"] = "True"
- app = create_app(None)
+ create_app(None)
os.environ.pop("SKIP_DAGS_PARSING")
- handle = setup_logging(log_file)
-
- pid_path = Path(pid_file)
- pidlock_path = pid_path.with_name(f"{pid_path.stem}-monitor{pid_path.suffix}")
-
- with open(stdout, "a") as stdout, open(stderr, "a") as stderr:
- stdout.truncate(0)
- stderr.truncate(0)
-
- ctx = daemon.DaemonContext(
- pidfile=TimeoutPIDLockFile(pidlock_path, -1),
- files_preserve=[handle],
- stdout=stdout,
- stderr=stderr,
- umask=int(settings.DAEMON_UMASK, 8),
- )
- with ctx:
- subprocess.Popen(run_args, close_fds=True)
-
- # Reading pid of gunicorn main process as it will be different that
- # the one of process spawned above.
- gunicorn_master_proc_pid = None
- while not gunicorn_master_proc_pid:
- sleep(0.1)
- gunicorn_master_proc_pid = read_pid_from_pidfile(pid_file)
-
- # Run Gunicorn monitor
- gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid)
- monitor_gunicorn(gunicorn_master_proc.pid)
-
- else:
- with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc:
- monitor_gunicorn(gunicorn_master_proc.pid)
+ pid_file_path = Path(pid_file)
+ monitor_pid_file = str(pid_file_path.with_name(f"{pid_file_path.stem}-monitor{pid_file_path.suffix}"))
+ run_command_with_daemon_option(
+ args=args,
+ process_name="internal-api",
+ callback=lambda: start_and_monitor_gunicorn(args),
+ should_setup_logging=True,
+ pid_file=monitor_pid_file,
+ )
def create_app(config=None, testing=False):
diff --git a/airflow/cli/commands/kerberos_command.py b/airflow/cli/commands/kerberos_command.py
index 4dd63d52eb..8d33e7f8ef 100644
--- a/airflow/cli/commands/kerberos_command.py
+++ b/airflow/cli/commands/kerberos_command.py
@@ -17,13 +17,10 @@
"""Kerberos command."""
from __future__ import annotations
-import daemon
-from daemon.pidfile import TimeoutPIDLockFile
-
from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
from airflow.security import kerberos as krb
from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
@@ -33,22 +30,8 @@ def kerberos(args):
"""Start a kerberos ticket renewer."""
print(settings.HEADER)
- if args.daemon:
- pid, stdout, stderr, _ = setup_locations(
- "kerberos", args.pid, args.stdout, args.stderr, args.log_file
- )
- with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
- stdout_handle.truncate(0)
- stderr_handle.truncate(0)
-
- ctx = daemon.DaemonContext(
- pidfile=TimeoutPIDLockFile(pid, -1),
- stdout=stdout_handle,
- stderr=stderr_handle,
- umask=int(settings.DAEMON_UMASK, 8),
- )
-
- with ctx:
- krb.run(principal=args.principal, keytab=args.keytab)
- else:
- krb.run(principal=args.principal, keytab=args.keytab)
+ run_command_with_daemon_option(
+ args=args,
+ process_name="kerberos",
+ callback=lambda: krb.run(principal=args.principal, keytab=args.keytab),
+ )
diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py
index fd25951ad3..fef0b97b2d 100644
--- a/airflow/cli/commands/scheduler_command.py
+++ b/airflow/cli/commands/scheduler_command.py
@@ -18,31 +18,33 @@
from __future__ import annotations
import logging
-import signal
+from argparse import Namespace
from contextlib import contextmanager
from multiprocessing import Process
-import daemon
-from daemon.pidfile import TimeoutPIDLockFile
-
from airflow import settings
from airflow.api_internal.internal_api_call import InternalApiConfig
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
from airflow.configuration import conf
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.job import Job, run_job
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
from airflow.utils import cli as cli_utils
-from airflow.utils.cli import process_subdir, setup_locations, setup_logging, sigint_handler, sigquit_handler
+from airflow.utils.cli import process_subdir
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.scheduler_health import serve_health_check
log = logging.getLogger(__name__)
-def _run_scheduler_job(job_runner: SchedulerJobRunner, *, skip_serve_logs: bool) -> None:
+def _run_scheduler_job(args) -> None:
+ job_runner = SchedulerJobRunner(
+ job=Job(), subdir=process_subdir(args.subdir), num_runs=args.num_runs, do_pickle=args.do_pickle
+ )
+
ExecutorLoader.validate_database_executor_compatibility(job_runner.job.executor)
InternalApiConfig.force_database_direct_access()
enable_health_check = conf.getboolean("scheduler", "ENABLE_HEALTH_CHECK")
- with _serve_logs(skip_serve_logs), _serve_health_check(enable_health_check):
+ with _serve_logs(args.skip_serve_logs), _serve_health_check(enable_health_check):
try:
run_job(job=job_runner.job, execute_callable=job_runner._execute)
except Exception:
@@ -51,38 +53,16 @@ def _run_scheduler_job(job_runner: SchedulerJobRunner, *, skip_serve_logs: bool)
@cli_utils.action_cli
@providers_configuration_loaded
-def scheduler(args):
+def scheduler(args: Namespace):
"""Start Airflow Scheduler."""
print(settings.HEADER)
- job_runner = SchedulerJobRunner(
- job=Job(), subdir=process_subdir(args.subdir), num_runs=args.num_runs, do_pickle=args.do_pickle
+ run_command_with_daemon_option(
+ args=args,
+ process_name="scheduler",
+ callback=lambda: _run_scheduler_job(args),
+ should_setup_logging=True,
)
- ExecutorLoader.validate_database_executor_compatibility(job_runner.job.executor)
-
- if args.daemon:
- pid, stdout, stderr, log_file = setup_locations(
- "scheduler", args.pid, args.stdout, args.stderr, args.log_file
- )
- handle = setup_logging(log_file)
- with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
- stdout_handle.truncate(0)
- stderr_handle.truncate(0)
-
- ctx = daemon.DaemonContext(
- pidfile=TimeoutPIDLockFile(pid, -1),
- files_preserve=[handle],
- stdout=stdout_handle,
- stderr=stderr_handle,
- umask=int(settings.DAEMON_UMASK, 8),
- )
- with ctx:
- _run_scheduler_job(job_runner, skip_serve_logs=args.skip_serve_logs)
- else:
- signal.signal(signal.SIGINT, sigint_handler)
- signal.signal(signal.SIGTERM, sigint_handler)
- signal.signal(signal.SIGQUIT, sigquit_handler)
- _run_scheduler_job(job_runner, skip_serve_logs=args.skip_serve_logs)
@contextmanager
diff --git a/airflow/cli/commands/triggerer_command.py b/airflow/cli/commands/triggerer_command.py
index 5ddb4e23b6..3479480dbf 100644
--- a/airflow/cli/commands/triggerer_command.py
+++ b/airflow/cli/commands/triggerer_command.py
@@ -17,21 +17,17 @@
"""Triggerer command."""
from __future__ import annotations
-import signal
from contextlib import contextmanager
from functools import partial
from multiprocessing import Process
from typing import Generator
-import daemon
-from daemon.pidfile import TimeoutPIDLockFile
-
from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
from airflow.configuration import conf
from airflow.jobs.job import Job, run_job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.serve_logs import serve_logs
@@ -51,6 +47,12 @@ def _serve_logs(skip_serve_logs: bool = False) -> Generator[None, None, None]:
sub_proc.terminate()
+def triggerer_run(skip_serve_logs: bool, capacity: int, triggerer_heartrate: float):
+ with _serve_logs(skip_serve_logs):
+ triggerer_job_runner = TriggererJobRunner(job=Job(heartrate=triggerer_heartrate), capacity=capacity)
+ run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute)
+
+
@cli_utils.action_cli
@providers_configuration_loaded
def triggerer(args):
@@ -59,33 +61,9 @@ def triggerer(args):
print(settings.HEADER)
triggerer_heartrate = conf.getfloat("triggerer", "JOB_HEARTBEAT_SEC")
- if args.daemon:
- pid, stdout, stderr, log_file = setup_locations(
- "triggerer", args.pid, args.stdout, args.stderr, args.log_file
- )
- handle = setup_logging(log_file)
- with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
- stdout_handle.truncate(0)
- stderr_handle.truncate(0)
-
- daemon_context = daemon.DaemonContext(
- pidfile=TimeoutPIDLockFile(pid, -1),
- files_preserve=[handle],
- stdout=stdout_handle,
- stderr=stderr_handle,
- umask=int(settings.DAEMON_UMASK, 8),
- )
- with daemon_context, _serve_logs(args.skip_serve_logs):
- triggerer_job_runner = TriggererJobRunner(
- job=Job(heartrate=triggerer_heartrate), capacity=args.capacity
- )
- run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute)
- else:
- signal.signal(signal.SIGINT, sigint_handler)
- signal.signal(signal.SIGTERM, sigint_handler)
- signal.signal(signal.SIGQUIT, sigquit_handler)
- with _serve_logs(args.skip_serve_logs):
- triggerer_job_runner = TriggererJobRunner(
- job=Job(heartrate=triggerer_heartrate), capacity=args.capacity
- )
- run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute)
+ run_command_with_daemon_option(
+ args=args,
+ process_name="triggerer",
+ callback=lambda: triggerer_run(args.skip_serve_logs, args.capacity, triggerer_heartrate),
+ should_setup_logging=True,
+ )
diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py
index 5ae601b428..4cb7939fd7 100644
--- a/airflow/cli/commands/webserver_command.py
+++ b/airflow/cli/commands/webserver_command.py
@@ -27,26 +27,21 @@ import time
from contextlib import suppress
from pathlib import Path
from time import sleep
-from typing import TYPE_CHECKING, NoReturn
+from typing import NoReturn
-import daemon
import psutil
-from daemon.pidfile import TimeoutPIDLockFile
from lockfile.pidlockfile import read_pid_from_pidfile
from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowWebServerTimeout
from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging
+from airflow.utils.cli import setup_locations
from airflow.utils.hashlib_wrapper import md5
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.process_utils import check_if_pidfile_process_is_running
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
-if TYPE_CHECKING:
- import types
-
log = logging.getLogger(__name__)
@@ -367,13 +362,6 @@ def webserver(args):
ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None,
)
else:
- pid_file, stdout, stderr, log_file = setup_locations(
- "webserver", args.pid, args.stdout, args.stderr, args.log_file
- )
-
- # Check if webserver is already running if not, remove old pidfile
- check_if_pidfile_process_is_running(pid_file=pid_file, process_name="webserver")
-
print(
textwrap.dedent(
f"""\
@@ -387,6 +375,7 @@ def webserver(args):
)
)
+ pid_file, _, _, _ = setup_locations("webserver", pid=args.pid)
run_args = [
sys.executable,
"-m",
@@ -436,9 +425,7 @@ def webserver(args):
# all writing to the database at the same time, we use the --preload option.
run_args += ["--preload"]
- gunicorn_master_proc: psutil.Process | subprocess.Popen
-
- def kill_proc(signum: int, frame: types.FrameType | None) -> NoReturn:
+ def kill_proc(signum: int, gunicorn_master_proc: psutil.Process | subprocess.Popen) -> NoReturn:
log.info("Received signal: %s. Closing gunicorn.", signum)
gunicorn_master_proc.terminate()
with suppress(TimeoutError):
@@ -451,14 +438,14 @@ def webserver(args):
gunicorn_master_proc.kill()
sys.exit(0)
- def monitor_gunicorn(gunicorn_master_pid: int) -> NoReturn:
+ def monitor_gunicorn(gunicorn_master_proc: psutil.Process | subprocess.Popen) -> NoReturn:
# Register signal handlers
- signal.signal(signal.SIGINT, kill_proc)
- signal.signal(signal.SIGTERM, kill_proc)
+ signal.signal(signal.SIGINT, lambda signum, _: kill_proc(signum, gunicorn_master_proc))
+ signal.signal(signal.SIGTERM, lambda signum, _: kill_proc(signum, gunicorn_master_proc))
# These run forever until SIG{INT, TERM, KILL, ...} signal is sent
GunicornMonitor(
- gunicorn_master_pid=gunicorn_master_pid,
+ gunicorn_master_pid=gunicorn_master_proc.pid,
num_workers_expected=num_workers,
- master_timeout=conf.getint("webserver", "web_server_master_timeout"),
worker_refresh_interval=conf.getint("webserver", "worker_refresh_interval", fallback=30),
@@ -468,42 +455,36 @@ def webserver(args):
),
).start()
+ def start_and_monitor_gunicorn(args):
+ if args.daemon:
+ subprocess.Popen(run_args, close_fds=True)
+
+ # Reading pid of gunicorn master as it will be different from
+ # the one of the process spawned above.
+ gunicorn_master_proc_pid = None
+ while not gunicorn_master_proc_pid:
+ sleep(0.1)
+ gunicorn_master_proc_pid = read_pid_from_pidfile(pid_file)
+
+ # Run Gunicorn monitor
+ gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid)
+ monitor_gunicorn(gunicorn_master_proc)
+ else:
+ with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc:
+ monitor_gunicorn(gunicorn_master_proc)
+
if args.daemon:
# This makes possible errors get reported before daemonization
os.environ["SKIP_DAGS_PARSING"] = "True"
- app = create_app(None)
+ create_app(None)
os.environ.pop("SKIP_DAGS_PARSING")
- handle = setup_logging(log_file)
-
- pid_path = Path(pid_file)
- pidlock_path = pid_path.with_name(f"{pid_path.stem}-monitor{pid_path.suffix}")
-
- with open(stdout, "a") as stdout, open(stderr, "a") as stderr:
- stdout.truncate(0)
- stderr.truncate(0)
-
- ctx = daemon.DaemonContext(
- pidfile=TimeoutPIDLockFile(pidlock_path, -1),
- files_preserve=[handle],
- stdout=stdout,
- stderr=stderr,
- umask=int(settings.DAEMON_UMASK, 8),
- )
- with ctx:
- subprocess.Popen(run_args, close_fds=True)
-
- # Reading pid of gunicorn master as it will be different that
- # the one of process spawned above.
- gunicorn_master_proc_pid = None
- while not gunicorn_master_proc_pid:
- sleep(0.1)
- gunicorn_master_proc_pid = read_pid_from_pidfile(pid_file)
-
- # Run Gunicorn monitor
- gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid)
- monitor_gunicorn(gunicorn_master_proc.pid)
-
- else:
- with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc:
- monitor_gunicorn(gunicorn_master_proc.pid)
+ pid_file_path = Path(pid_file)
+ monitor_pid_file = str(pid_file_path.with_name(f"{pid_file_path.stem}-monitor{pid_file_path.suffix}"))
+ run_command_with_daemon_option(
+ args=args,
+ process_name="webserver",
+ callback=lambda: start_and_monitor_gunicorn(args),
+ should_setup_logging=True,
+ pid_file=monitor_pid_file,
+ )
diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py
index ae968f1171..02f26d7f23 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -266,9 +266,9 @@ class TestFlowerCommand:
]
)
- @mock.patch("airflow.cli.commands.celery_command.TimeoutPIDLockFile")
- @mock.patch("airflow.cli.commands.celery_command.setup_locations")
- @mock.patch("airflow.cli.commands.celery_command.daemon")
+ @mock.patch("airflow.cli.commands.daemon_utils.TimeoutPIDLockFile")
+ @mock.patch("airflow.cli.commands.daemon_utils.setup_locations")
+ @mock.patch("airflow.cli.commands.daemon_utils.daemon")
@mock.patch("airflow.providers.celery.executors.celery_executor.app")
def test_run_command_daemon(self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file):
mock_setup_locations.return_value = (
@@ -305,7 +305,7 @@ class TestFlowerCommand:
]
)
mock_open = mock.mock_open()
- with mock.patch("airflow.cli.commands.celery_command.open", mock_open):
+ with mock.patch("airflow.cli.commands.daemon_utils.open", mock_open):
celery_command.flower(args)
mock_celery_app.start.assert_called_once_with(
@@ -320,11 +320,12 @@ class TestFlowerCommand:
"--conf=flower_config",
]
)
- assert mock_daemon.mock_calls == [
+ assert mock_daemon.mock_calls[:3] == [
mock.call.DaemonContext(
pidfile=mock_pid_file.return_value,
- stderr=mock_open.return_value,
+ files_preserve=None,
stdout=mock_open.return_value,
+ stderr=mock_open.return_value,
umask=0o077,
),
mock.call.DaemonContext().__enter__(),
@@ -333,11 +334,10 @@ class TestFlowerCommand:
assert mock_setup_locations.mock_calls == [
mock.call(
- log="/tmp/flower.log",
- pid="/tmp/flower.pid",
process="flower",
- stderr="/tmp/flower-stderr.log",
stdout="/tmp/flower-stdout.log",
+ stderr="/tmp/flower-stderr.log",
+ log="/tmp/flower.log",
)
]
mock_pid_file.assert_has_calls([mock.call(mock_setup_locations.return_value[0], -1)])
diff --git a/tests/cli/commands/test_kerberos_command.py b/tests/cli/commands/test_kerberos_command.py
index 41dce045fa..14eb1676bd 100644
--- a/tests/cli/commands/test_kerberos_command.py
+++ b/tests/cli/commands/test_kerberos_command.py
@@ -36,9 +36,9 @@ class TestKerberosCommand:
kerberos_command.kerberos(args)
mock_krb.run.assert_called_once_with(keytab="/tmp/airflow.keytab", principal="PRINCIPAL")
- @mock.patch("airflow.cli.commands.kerberos_command.TimeoutPIDLockFile")
- @mock.patch("airflow.cli.commands.kerberos_command.setup_locations")
- @mock.patch("airflow.cli.commands.kerberos_command.daemon")
+ @mock.patch("airflow.cli.commands.daemon_utils.TimeoutPIDLockFile")
+ @mock.patch("airflow.cli.commands.daemon_utils.setup_locations")
+ @mock.patch("airflow.cli.commands.daemon_utils.daemon")
@mock.patch("airflow.cli.commands.kerberos_command.krb")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_run_command_daemon(self, mock_krb, mock_daemon, mock_setup_locations, mock_pid_file):
@@ -66,13 +66,14 @@ class TestKerberosCommand:
]
)
mock_open = mock.mock_open()
- with mock.patch("airflow.cli.commands.kerberos_command.open",
mock_open):
+ with mock.patch("airflow.cli.commands.daemon_utils.open", mock_open):
kerberos_command.kerberos(args)
mock_krb.run.assert_called_once_with(keytab="/tmp/airflow.keytab", principal="PRINCIPAL")
- assert mock_daemon.mock_calls == [
+ assert mock_daemon.mock_calls[:3] == [
mock.call.DaemonContext(
pidfile=mock_pid_file.return_value,
+ files_preserve=None,
stderr=mock_open.return_value,
stdout=mock_open.return_value,
umask=0o077,
@@ -81,18 +82,14 @@ class TestKerberosCommand:
mock.call.DaemonContext().__exit__(None, None, None),
]
- mock_setup_locations.assert_has_calls(
- [
- mock.call(
- "kerberos",
- "/tmp/kerberos.pid",
- "/tmp/kerberos-stdout.log",
- "/tmp/kerberos-stderr.log",
- "/tmp/kerberos.log",
- )
- ]
+ assert mock_setup_locations.mock_calls[0] == mock.call(
+ process="kerberos",
+ stdout="/tmp/kerberos-stdout.log",
+ stderr="/tmp/kerberos-stderr.log",
+ log="/tmp/kerberos.log",
)
-
mock_pid_file.assert_has_calls([mock.call(mock_setup_locations.return_value[0], -1)])
+
+ mock_pid_file.mock_calls[0] = mock.call(mock_setup_locations.return_value[0], -1)
assert mock_open.mock_calls == [
mock.call(mock_setup_locations.return_value[1], "a"),
mock.call().__enter__(),