This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 34edb6508d Add protection against accidental Providers Manager
initialization (#32694)
34edb6508d is described below
commit 34edb6508d0e5d1f185abf400b9572139fb12b16
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Jul 19 18:13:38 2023 +0200
Add protection against accidental Providers Manager initialization (#32694)
ProvidersManager might get accidentally initialized during CLI
argument parsing, and if it does, then CLI and arg completion
take a huge performance hit - it takes seconds to initialize
ProvidersManager when you have many provider packages installed,
so initializing ProvidersManager should not happen just when you
parse arguments of the CLIs but only when you execute commands.
It's rather easy to trigger ProvidersManager by just importing
some package (especially as we move executors to providers).
This PR adds a new CLI command in providers to check the status of
initialization - this command will fail when the provider's initialization
is executed during CLI parsing. It will also show the stack
trace of where ProvidersManager has been initialized from.
A test is added to detect such a situation by running the new command
in a separate subprocess, so that the result is not affected by
ProvidersManager having been initialized in another test.
---
airflow/cli/cli_config.py | 7 ++++++
airflow/cli/commands/celery_command.py | 7 +++++-
airflow/cli/commands/provider_command.py | 27 ++++++++++++++++++++++
airflow/providers_manager.py | 14 +++++++++++
tests/cli/commands/test_celery_command.py | 12 +++++-----
tests/cli/test_cli_parser.py | 10 ++++++++
.../cli/commands/test_celery_command.py | 14 ++++++-----
7 files changed, 78 insertions(+), 13 deletions(-)
diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index 97aa83c100..66b9590513 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -1820,8 +1820,15 @@ PROVIDERS_COMMANDS = (
func=lazy_load_command("airflow.cli.commands.provider_command.executors_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
+ ActionCommand(
+ name="status",
+ help="Get information about provider initialization status",
+ func=lazy_load_command("airflow.cli.commands.provider_command.status"),
+ args=(ARG_VERBOSE,),
+ ),
)
+
USERS_COMMANDS = (
ActionCommand(
name="list",
diff --git a/airflow/cli/commands/celery_command.py
b/airflow/cli/commands/celery_command.py
index 6d3dbe64bb..7adbea36e6 100644
--- a/airflow/cli/commands/celery_command.py
+++ b/airflow/cli/commands/celery_command.py
@@ -34,7 +34,6 @@ from lockfile.pidlockfile import read_pid_from_pidfile,
remove_existing_pidfile
from airflow import settings
from airflow.configuration import conf
-from airflow.providers.celery.executors.celery_executor import app as
celery_app
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations, setup_logging
from airflow.utils.serve_logs import serve_logs
@@ -45,6 +44,9 @@ WORKER_PROCESS_NAME = "worker"
@cli_utils.action_cli
def flower(args):
"""Starts Flower, Celery monitoring tool."""
+ # This needs to be imported locally to not trigger Providers Manager
initialization
+ from airflow.providers.celery.executors.celery_executor import app as
celery_app
+
options = [
"flower",
conf.get("celery", "BROKER_URL"),
@@ -132,6 +134,9 @@ def logger_setup_handler(logger, **kwargs):
@cli_utils.action_cli
def worker(args):
"""Starts Airflow Celery worker."""
+ # This needs to be imported locally to not trigger Providers Manager
initialization
+ from airflow.providers.celery.executors.celery_executor import app as
celery_app
+
# Disable connection pool so that celery worker does not hold an
unnecessary db connection
settings.reconfigure_orm(disable_connection_pool=True)
if not settings.validate_session():
diff --git a/airflow/cli/commands/provider_command.py
b/airflow/cli/commands/provider_command.py
index 67c78c57fe..0256a74089 100644
--- a/airflow/cli/commands/provider_command.py
+++ b/airflow/cli/commands/provider_command.py
@@ -17,6 +17,8 @@
"""Providers sub-commands."""
from __future__ import annotations
+import sys
+
import re2
from airflow.cli.simple_table import AirflowConsole
@@ -179,3 +181,28 @@ def executors_list(args):
"executor_class_names": x,
},
)
+
+
+@suppress_logs_and_warning
+def status(args):
+ """Informs if providers manager has been initialized.
+
+ If provider is initialized, shows the stack trace and exit with error code
1.
+ """
+ import rich
+
+ if ProvidersManager.initialized():
+ rich.print(
+ "\n[red]ProvidersManager was initialized during CLI parsing. This
should not happen.\n",
+ file=sys.stderr,
+ )
+ rich.print(
+ "\n[yellow]Please make sure no Providers Manager initialization
happens during CLI parsing.\n",
+ file=sys.stderr,
+ )
+ rich.print("Stack trace where it has been initialized:\n",
file=sys.stderr)
+ rich.print(ProvidersManager.initialization_stack_trace(),
file=sys.stderr)
+ sys.exit(1)
+ else:
+ rich.print("[green]All ok. Providers Manager was not initialized
during the CLI parsing.")
+ sys.exit(0)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 4f2f063599..108712d484 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -20,10 +20,12 @@ from __future__ import annotations
import fnmatch
import functools
+import inspect
import json
import logging
import os
import sys
+import traceback
import warnings
from collections import OrderedDict
from dataclasses import dataclass
@@ -377,10 +379,22 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
"""
resource_version = "0"
+ _initialized: bool = False
+ _initialization_stack_trace = None
+
+ @staticmethod
+ def initialized() -> bool:
+ return ProvidersManager._initialized
+
+ @staticmethod
+ def initialization_stack_trace() -> str:
+ return ProvidersManager._initialization_stack_trace
def __init__(self):
"""Initializes the manager."""
super().__init__()
+ ProvidersManager._initialized = True
+ ProvidersManager._initialization_stack_trace =
"".join(traceback.format_stack(inspect.currentframe()))
self._initialized_cache: dict[str, bool] = {}
# Keeps dict of providers keyed by module name
self._provider_dict: dict[str, ProviderInfo] = {}
diff --git a/tests/cli/commands/test_celery_command.py
b/tests/cli/commands/test_celery_command.py
index 9acea81c60..0e8dee2457 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -88,7 +88,7 @@ class TestCeleryStopCommand:
mock_process.return_value.terminate.assert_called_once_with()
@mock.patch("airflow.cli.commands.celery_command.read_pid_from_pidfile")
- @mock.patch("airflow.cli.commands.celery_command.celery_app")
+ @mock.patch("airflow.providers.celery.executors.celery_executor.app")
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_same_pid_file_is_used_in_start_and_stop(
@@ -113,7 +113,7 @@ class TestCeleryStopCommand:
@mock.patch("airflow.cli.commands.celery_command.remove_existing_pidfile")
@mock.patch("airflow.cli.commands.celery_command.read_pid_from_pidfile")
- @mock.patch("airflow.cli.commands.celery_command.celery_app")
+ @mock.patch("airflow.providers.celery.executors.celery_executor.app")
@mock.patch("airflow.cli.commands.celery_command.psutil.Process")
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@conf_vars({("core", "executor"): "CeleryExecutor"})
@@ -151,7 +151,7 @@ class TestWorkerStart:
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@mock.patch("airflow.cli.commands.celery_command.Process")
- @mock.patch("airflow.cli.commands.celery_command.celery_app")
+ @mock.patch("airflow.providers.celery.executors.celery_executor.app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_worker_started_with_required_arguments(self, mock_celery_app,
mock_popen, mock_locations):
pid_file = "pid_file"
@@ -211,7 +211,7 @@ class TestWorkerFailure:
cls.parser = cli_parser.get_parser()
@mock.patch("airflow.cli.commands.celery_command.Process")
- @mock.patch("airflow.cli.commands.celery_command.celery_app")
+ @mock.patch("airflow.providers.celery.executors.celery_executor.app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_worker_failure_gracefull_shutdown(self, mock_celery_app,
mock_popen):
args = self.parser.parse_args(["celery", "worker"])
@@ -228,7 +228,7 @@ class TestFlowerCommand:
def setup_class(cls):
cls.parser = cli_parser.get_parser()
- @mock.patch("airflow.cli.commands.celery_command.celery_app")
+ @mock.patch("airflow.providers.celery.executors.celery_executor.app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_run_command(self, mock_celery_app):
args = self.parser.parse_args(
@@ -267,7 +267,7 @@ class TestFlowerCommand:
@mock.patch("airflow.cli.commands.celery_command.TimeoutPIDLockFile")
@mock.patch("airflow.cli.commands.celery_command.setup_locations")
@mock.patch("airflow.cli.commands.celery_command.daemon")
- @mock.patch("airflow.cli.commands.celery_command.celery_app")
+ @mock.patch("airflow.providers.celery.executors.celery_executor.app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
def test_run_command_daemon(self, mock_celery_app, mock_daemon,
mock_setup_locations, mock_pid_file):
mock_setup_locations.return_value = (
diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py
index 935417b4b6..378ebd4e71 100644
--- a/tests/cli/test_cli_parser.py
+++ b/tests/cli/test_cli_parser.py
@@ -22,6 +22,7 @@ import argparse
import contextlib
import io
import re
+import subprocess
import timeit
from collections import Counter
from unittest.mock import patch
@@ -279,3 +280,12 @@ class TestCli:
timing_result = timeit.timeit(stmt=timing_code, number=num_samples,
setup=setup_code) / num_samples
# Average run time of Airflow CLI should at least be within 3.5s
assert timing_result < threshold
+
+ def test_cli_parsing_does_not_initialize_providers_manager(self):
+ """Test that CLI parsing does not initialize providers manager.
+
+ This test is here to make sure that we do not initialize providers
manager - it is run as a
+ separate subprocess, to make sure we do not have providers manager
initialized in the main
+ process from other tests.
+ """
+ subprocess.check_call(["airflow", "providers", "status"])
diff --git a/tests/integration/cli/commands/test_celery_command.py
b/tests/integration/cli/commands/test_celery_command.py
index 306d5fcfd8..b421c949f0 100644
--- a/tests/integration/cli/commands/test_celery_command.py
+++ b/tests/integration/cli/commands/test_celery_command.py
@@ -33,10 +33,11 @@ class TestWorkerServeLogs:
def setup_class(cls):
cls.parser = cli_parser.get_parser()
- @mock.patch("airflow.cli.commands.celery_command.celery_app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
- def test_serve_logs_on_worker_start(self, mock_celery_app):
- with mock.patch("airflow.cli.commands.celery_command.Process") as
mock_process:
+ def test_serve_logs_on_worker_start(self):
+ with mock.patch("airflow.cli.commands.celery_command.Process") as
mock_process, mock.patch(
+ "airflow.providers.celery.executors.celery_executor.app"
+ ):
args = self.parser.parse_args(["celery", "worker",
"--concurrency", "1"])
with mock.patch("celery.platforms.check_privileges") as
mock_privil:
@@ -44,10 +45,11 @@ class TestWorkerServeLogs:
celery_command.worker(args)
mock_process.assert_called()
- @mock.patch("airflow.cli.commands.celery_command.celery_app")
@conf_vars({("core", "executor"): "CeleryExecutor"})
- def test_skip_serve_logs_on_worker_start(self, mock_celery_app):
- with mock.patch("airflow.cli.commands.celery_command.Process") as
mock_popen:
+ def test_skip_serve_logs_on_worker_start(self):
+ with mock.patch("airflow.cli.commands.celery_command.Process") as
mock_popen, mock.patch(
+ "airflow.providers.celery.executors.celery_executor.app"
+ ):
args = self.parser.parse_args(["celery", "worker",
"--concurrency", "1", "--skip-serve-logs"])
with mock.patch("celery.platforms.check_privileges") as
mock_privil: