This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-5-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 45caa5ec146760c273fe96b05135b8f2e786bee8 Author: Tzu-ping Chung <[email protected]> AuthorDate: Sat Feb 25 13:36:14 2023 +0800 Aggressively cache entry points in process (#29625) (cherry picked from commit 9f51845fdc305e2f5847584e984278c906f9f293) --- airflow/providers_manager.py | 6 ++++-- airflow/utils/entry_points.py | 44 +++++++++++++++++++++++----------------- tests/conftest.py | 7 +++++++ tests/utils/test_entry_points.py | 6 +++--- 4 files changed, 39 insertions(+), 24 deletions(-) diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 6088e3b373..a0fe51510e 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -31,6 +31,8 @@ from functools import wraps from time import perf_counter from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast +from packaging.utils import canonicalize_name + from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.typing_compat import Literal from airflow.utils import yaml @@ -454,8 +456,8 @@ class ProvidersManager(LoggingMixin): and verifies only the subset of fields that are needed at runtime. """ for entry_point, dist in entry_points_with_dist("apache_airflow_provider"): - package_name = dist.metadata["name"] - if self._provider_dict.get(package_name) is not None: + package_name = canonicalize_name(dist.metadata["name"]) + if package_name in self._provider_dict: continue log.debug("Loading %s from package %s", entry_point, package_name) version = dist.version diff --git a/airflow/utils/entry_points.py b/airflow/utils/entry_points.py index 41ea38845f..b3f145110f 100644 --- a/airflow/utils/entry_points.py +++ b/airflow/utils/entry_points.py @@ -16,10 +16,10 @@ # under the License. 
from __future__ import annotations +import collections +import functools import logging -from typing import Iterator - -from packaging.utils import canonicalize_name +from typing import Iterator, Tuple try: import importlib_metadata as metadata @@ -28,26 +28,32 @@ except ImportError: log = logging.getLogger(__name__) +EPnD = Tuple[metadata.EntryPoint, metadata.Distribution] -def entry_points_with_dist(group: str) -> Iterator[tuple[metadata.EntryPoint, metadata.Distribution]]: - """Retrieve entry points of the given group. - - This is like the ``entry_points()`` function from importlib.metadata, - except it also returns the distribution the entry_point was loaded from. - :param group: Filter results to only this entrypoint group - :return: Generator of (EntryPoint, Distribution) objects for the specified groups - """ - loaded: set[str] = set() +@functools.lru_cache(maxsize=None) +def _get_grouped_entry_points() -> dict[str, list[EPnD]]: + mapping: dict[str, list[EPnD]] = collections.defaultdict(list) for dist in metadata.distributions(): try: - key = canonicalize_name(dist.metadata["Name"]) - if key in loaded: - continue - loaded.add(key) for e in dist.entry_points: - if e.group != group: - continue - yield e, dist + mapping[e.group].append((e, dist)) except Exception as e: log.warning("Error when retrieving package metadata (skipping it): %s, %s", dist, e) + return mapping + + +def entry_points_with_dist(group: str) -> Iterator[EPnD]: + """Retrieve entry points of the given group. + + This is like the ``entry_points()`` function from ``importlib.metadata``, + except it also returns the distribution the entry point was loaded from. + + Note that this may return multiple distributions to the same package if they + are loaded from different ``sys.path`` entries. The caller site should + implement appropriate deduplication logic if needed. 
+ + :param group: Filter results to only this entrypoint group + :return: Generator of (EntryPoint, Distribution) objects for the specified groups + """ + return iter(_get_grouped_entry_points()[group]) diff --git a/tests/conftest.py b/tests/conftest.py index 945ece6e67..bdaec7da0f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -856,3 +856,10 @@ def reset_logging_config(): logging_config = import_string(settings.LOGGING_CLASS_PATH) logging.config.dictConfig(logging_config) + + +@pytest.fixture(autouse=True) +def _clear_entry_point_cache(): + from airflow.utils.entry_points import _get_grouped_entry_points + + _get_grouped_entry_points.cache_clear() diff --git a/tests/utils/test_entry_points.py b/tests/utils/test_entry_points.py index de4843dbaa..22537245fc 100644 --- a/tests/utils/test_entry_points.py +++ b/tests/utils/test_entry_points.py @@ -45,6 +45,6 @@ class MockMetadata: def test_entry_points_with_dist(): entries = list(entry_points_with_dist("group_x")) - # The second "dist2" is ignored. Only "group_x" entries are loaded. - assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2"] - assert [ep.name for ep, _ in entries] == ["a", "e"] + # Only "group_x" entries are loaded. Distributions are not deduplicated. + assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2", "dist2"] + assert [ep.name for ep, _ in entries] == ["a", "e", "g"]
