This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 34edb6508d Add protection against accidental Providers Manager 
initialization (#32694)
34edb6508d is described below

commit 34edb6508d0e5d1f185abf400b9572139fb12b16
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Jul 19 18:13:38 2023 +0200

    Add protection against accidental Providers Manager initialization (#32694)
    
    ProvidersManager might get accidentally initialized during CLI
    argument parsing, and if it does, then CLI and arg completion
    take a huge performance hit - it takes seconds to initialize
    ProvidersManager when you have many providers packages installed,
    so initializing ProvidersManager should not happen just when you
    parse arguments of the CLIs but only when you execute commands.
    
    It's rather easy to trigger ProvidersManager by just importing
    some package (especially as we move executors to providers).
    
    This PR adds a new CLI command in providers to check status of
    initialization - this command will fail when provider's initialization
    is executed during CLI parsing. It will also show the stack
    trace of where ProvidersManager has been initialized from.
    
    A test is added to detect such a situation by running the new command
    in a separate subprocess, thus making sure ProvidersManager has not
    been initialized in another test.
---
 airflow/cli/cli_config.py                          |  7 ++++++
 airflow/cli/commands/celery_command.py             |  7 +++++-
 airflow/cli/commands/provider_command.py           | 27 ++++++++++++++++++++++
 airflow/providers_manager.py                       | 14 +++++++++++
 tests/cli/commands/test_celery_command.py          | 12 +++++-----
 tests/cli/test_cli_parser.py                       | 10 ++++++++
 .../cli/commands/test_celery_command.py            | 14 ++++++-----
 7 files changed, 78 insertions(+), 13 deletions(-)

diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index 97aa83c100..66b9590513 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -1820,8 +1820,15 @@ PROVIDERS_COMMANDS = (
         
func=lazy_load_command("airflow.cli.commands.provider_command.executors_list"),
         args=(ARG_OUTPUT, ARG_VERBOSE),
     ),
+    ActionCommand(
+        name="status",
+        help="Get information about provider initialization status",
+        func=lazy_load_command("airflow.cli.commands.provider_command.status"),
+        args=(ARG_VERBOSE,),
+    ),
 )
 
+
 USERS_COMMANDS = (
     ActionCommand(
         name="list",
diff --git a/airflow/cli/commands/celery_command.py 
b/airflow/cli/commands/celery_command.py
index 6d3dbe64bb..7adbea36e6 100644
--- a/airflow/cli/commands/celery_command.py
+++ b/airflow/cli/commands/celery_command.py
@@ -34,7 +34,6 @@ from lockfile.pidlockfile import read_pid_from_pidfile, 
remove_existing_pidfile
 
 from airflow import settings
 from airflow.configuration import conf
-from airflow.providers.celery.executors.celery_executor import app as 
celery_app
 from airflow.utils import cli as cli_utils
 from airflow.utils.cli import setup_locations, setup_logging
 from airflow.utils.serve_logs import serve_logs
@@ -45,6 +44,9 @@ WORKER_PROCESS_NAME = "worker"
 @cli_utils.action_cli
 def flower(args):
     """Starts Flower, Celery monitoring tool."""
+    # This needs to be imported locally to not trigger Providers Manager 
initialization
+    from airflow.providers.celery.executors.celery_executor import app as 
celery_app
+
     options = [
         "flower",
         conf.get("celery", "BROKER_URL"),
@@ -132,6 +134,9 @@ def logger_setup_handler(logger, **kwargs):
 @cli_utils.action_cli
 def worker(args):
     """Starts Airflow Celery worker."""
+    # This needs to be imported locally to not trigger Providers Manager 
initialization
+    from airflow.providers.celery.executors.celery_executor import app as 
celery_app
+
     # Disable connection pool so that celery worker does not hold an 
unnecessary db connection
     settings.reconfigure_orm(disable_connection_pool=True)
     if not settings.validate_session():
diff --git a/airflow/cli/commands/provider_command.py 
b/airflow/cli/commands/provider_command.py
index 67c78c57fe..0256a74089 100644
--- a/airflow/cli/commands/provider_command.py
+++ b/airflow/cli/commands/provider_command.py
@@ -17,6 +17,8 @@
 """Providers sub-commands."""
 from __future__ import annotations
 
+import sys
+
 import re2
 
 from airflow.cli.simple_table import AirflowConsole
@@ -179,3 +181,28 @@ def executors_list(args):
             "executor_class_names": x,
         },
     )
+
+
+@suppress_logs_and_warning
+def status(args):
+    """Informs if providers manager has been initialized.
+
+    If provider is initialized, shows the stack trace and exit with error code 
1.
+    """
+    import rich
+
+    if ProvidersManager.initialized():
+        rich.print(
+            "\n[red]ProvidersManager was initialized during CLI parsing. This 
should not happen.\n",
+            file=sys.stderr,
+        )
+        rich.print(
+            "\n[yellow]Please make sure no Providers Manager initialization 
happens during CLI parsing.\n",
+            file=sys.stderr,
+        )
+        rich.print("Stack trace where it has been initialized:\n", 
file=sys.stderr)
+        rich.print(ProvidersManager.initialization_stack_trace(), 
file=sys.stderr)
+        sys.exit(1)
+    else:
+        rich.print("[green]All ok. Providers Manager was not initialized 
during the CLI parsing.")
+        sys.exit(0)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 4f2f063599..108712d484 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -20,10 +20,12 @@ from __future__ import annotations
 
 import fnmatch
 import functools
+import inspect
 import json
 import logging
 import os
 import sys
+import traceback
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass
@@ -377,10 +379,22 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
     """
 
     resource_version = "0"
+    _initialized: bool = False
+    _initialization_stack_trace = None
+
+    @staticmethod
+    def initialized() -> bool:
+        return ProvidersManager._initialized
+
+    @staticmethod
+    def initialization_stack_trace() -> str:
+        return ProvidersManager._initialization_stack_trace
 
     def __init__(self):
         """Initializes the manager."""
         super().__init__()
+        ProvidersManager._initialized = True
+        ProvidersManager._initialization_stack_trace = 
"".join(traceback.format_stack(inspect.currentframe()))
         self._initialized_cache: dict[str, bool] = {}
         # Keeps dict of providers keyed by module name
         self._provider_dict: dict[str, ProviderInfo] = {}
diff --git a/tests/cli/commands/test_celery_command.py 
b/tests/cli/commands/test_celery_command.py
index 9acea81c60..0e8dee2457 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -88,7 +88,7 @@ class TestCeleryStopCommand:
                 mock_process.return_value.terminate.assert_called_once_with()
 
     @mock.patch("airflow.cli.commands.celery_command.read_pid_from_pidfile")
-    @mock.patch("airflow.cli.commands.celery_command.celery_app")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
     @mock.patch("airflow.cli.commands.celery_command.setup_locations")
     @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_same_pid_file_is_used_in_start_and_stop(
@@ -113,7 +113,7 @@ class TestCeleryStopCommand:
 
     @mock.patch("airflow.cli.commands.celery_command.remove_existing_pidfile")
     @mock.patch("airflow.cli.commands.celery_command.read_pid_from_pidfile")
-    @mock.patch("airflow.cli.commands.celery_command.celery_app")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
     @mock.patch("airflow.cli.commands.celery_command.psutil.Process")
     @mock.patch("airflow.cli.commands.celery_command.setup_locations")
     @conf_vars({("core", "executor"): "CeleryExecutor"})
@@ -151,7 +151,7 @@ class TestWorkerStart:
 
     @mock.patch("airflow.cli.commands.celery_command.setup_locations")
     @mock.patch("airflow.cli.commands.celery_command.Process")
-    @mock.patch("airflow.cli.commands.celery_command.celery_app")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
     @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_worker_started_with_required_arguments(self, mock_celery_app, 
mock_popen, mock_locations):
         pid_file = "pid_file"
@@ -211,7 +211,7 @@ class TestWorkerFailure:
         cls.parser = cli_parser.get_parser()
 
     @mock.patch("airflow.cli.commands.celery_command.Process")
-    @mock.patch("airflow.cli.commands.celery_command.celery_app")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
     @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_worker_failure_gracefull_shutdown(self, mock_celery_app, 
mock_popen):
         args = self.parser.parse_args(["celery", "worker"])
@@ -228,7 +228,7 @@ class TestFlowerCommand:
     def setup_class(cls):
         cls.parser = cli_parser.get_parser()
 
-    @mock.patch("airflow.cli.commands.celery_command.celery_app")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
     @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_run_command(self, mock_celery_app):
         args = self.parser.parse_args(
@@ -267,7 +267,7 @@ class TestFlowerCommand:
     @mock.patch("airflow.cli.commands.celery_command.TimeoutPIDLockFile")
     @mock.patch("airflow.cli.commands.celery_command.setup_locations")
     @mock.patch("airflow.cli.commands.celery_command.daemon")
-    @mock.patch("airflow.cli.commands.celery_command.celery_app")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
     @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_run_command_daemon(self, mock_celery_app, mock_daemon, 
mock_setup_locations, mock_pid_file):
         mock_setup_locations.return_value = (
diff --git a/tests/cli/test_cli_parser.py b/tests/cli/test_cli_parser.py
index 935417b4b6..378ebd4e71 100644
--- a/tests/cli/test_cli_parser.py
+++ b/tests/cli/test_cli_parser.py
@@ -22,6 +22,7 @@ import argparse
 import contextlib
 import io
 import re
+import subprocess
 import timeit
 from collections import Counter
 from unittest.mock import patch
@@ -279,3 +280,12 @@ class TestCli:
         timing_result = timeit.timeit(stmt=timing_code, number=num_samples, 
setup=setup_code) / num_samples
         # Average run time of Airflow CLI should at least be within 3.5s
         assert timing_result < threshold
+
+    def test_cli_parsing_does_not_initialize_providers_manager(self):
+        """Test that CLI parsing does not initialize providers manager.
+
+        This test is here to make sure that we do not initialize providers 
manager - it is run as a
+        separate subprocess, to make sure we do not have providers manager 
initialized in the main
+        process from other tests.
+        """
+        subprocess.check_call(["airflow", "providers", "status"])
diff --git a/tests/integration/cli/commands/test_celery_command.py 
b/tests/integration/cli/commands/test_celery_command.py
index 306d5fcfd8..b421c949f0 100644
--- a/tests/integration/cli/commands/test_celery_command.py
+++ b/tests/integration/cli/commands/test_celery_command.py
@@ -33,10 +33,11 @@ class TestWorkerServeLogs:
     def setup_class(cls):
         cls.parser = cli_parser.get_parser()
 
-    @mock.patch("airflow.cli.commands.celery_command.celery_app")
     @conf_vars({("core", "executor"): "CeleryExecutor"})
-    def test_serve_logs_on_worker_start(self, mock_celery_app):
-        with mock.patch("airflow.cli.commands.celery_command.Process") as 
mock_process:
+    def test_serve_logs_on_worker_start(self):
+        with mock.patch("airflow.cli.commands.celery_command.Process") as 
mock_process, mock.patch(
+            "airflow.providers.celery.executors.celery_executor.app"
+        ):
             args = self.parser.parse_args(["celery", "worker", 
"--concurrency", "1"])
 
             with mock.patch("celery.platforms.check_privileges") as 
mock_privil:
@@ -44,10 +45,11 @@ class TestWorkerServeLogs:
                 celery_command.worker(args)
                 mock_process.assert_called()
 
-    @mock.patch("airflow.cli.commands.celery_command.celery_app")
     @conf_vars({("core", "executor"): "CeleryExecutor"})
-    def test_skip_serve_logs_on_worker_start(self, mock_celery_app):
-        with mock.patch("airflow.cli.commands.celery_command.Process") as 
mock_popen:
+    def test_skip_serve_logs_on_worker_start(self):
+        with mock.patch("airflow.cli.commands.celery_command.Process") as 
mock_popen, mock.patch(
+            "airflow.providers.celery.executors.celery_executor.app"
+        ):
             args = self.parser.parse_args(["celery", "worker", 
"--concurrency", "1", "--skip-serve-logs"])
 
             with mock.patch("celery.platforms.check_privileges") as 
mock_privil:

Reply via email to