This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9f51845fdc Aggressively cache entry points in process (#29625)
9f51845fdc is described below
commit 9f51845fdc305e2f5847584e984278c906f9f293
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Sat Feb 25 13:36:14 2023 +0800
Aggressively cache entry points in process (#29625)
---
airflow/providers_manager.py | 6 ++++--
airflow/utils/entry_points.py | 44 +++++++++++++++++++++++-----------------
tests/conftest.py | 7 +++++++
tests/utils/test_entry_points.py | 6 +++---
4 files changed, 39 insertions(+), 24 deletions(-)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 9dd3a06106..2cab30d89c 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -31,6 +31,8 @@ from functools import wraps
from time import perf_counter
from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast
+from packaging.utils import canonicalize_name
+
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.typing_compat import Literal
from airflow.utils import yaml
@@ -467,8 +469,8 @@ class ProvidersManager(LoggingMixin):
and verifies only the subset of fields that are needed at runtime.
"""
for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
- package_name = dist.metadata["name"]
- if self._provider_dict.get(package_name) is not None:
+ package_name = canonicalize_name(dist.metadata["name"])
+ if package_name in self._provider_dict:
continue
log.debug("Loading %s from package %s", entry_point, package_name)
version = dist.version
diff --git a/airflow/utils/entry_points.py b/airflow/utils/entry_points.py
index 41ea38845f..b3f145110f 100644
--- a/airflow/utils/entry_points.py
+++ b/airflow/utils/entry_points.py
@@ -16,10 +16,10 @@
# under the License.
from __future__ import annotations
+import collections
+import functools
import logging
-from typing import Iterator
-
-from packaging.utils import canonicalize_name
+from typing import Iterator, Tuple
try:
import importlib_metadata as metadata
@@ -28,26 +28,32 @@ except ImportError:
log = logging.getLogger(__name__)
+EPnD = Tuple[metadata.EntryPoint, metadata.Distribution]
-def entry_points_with_dist(group: str) -> Iterator[tuple[metadata.EntryPoint, metadata.Distribution]]:
- """Retrieve entry points of the given group.
-
- This is like the ``entry_points()`` function from importlib.metadata,
- except it also returns the distribution the entry_point was loaded from.
- :param group: Filter results to only this entrypoint group
- :return: Generator of (EntryPoint, Distribution) objects for the specified groups
- """
- loaded: set[str] = set()
+@functools.lru_cache(maxsize=None)
+def _get_grouped_entry_points() -> dict[str, list[EPnD]]:
+ mapping: dict[str, list[EPnD]] = collections.defaultdict(list)
for dist in metadata.distributions():
try:
- key = canonicalize_name(dist.metadata["Name"])
- if key in loaded:
- continue
- loaded.add(key)
for e in dist.entry_points:
- if e.group != group:
- continue
- yield e, dist
+ mapping[e.group].append((e, dist))
except Exception as e:
log.warning("Error when retrieving package metadata (skipping it): %s, %s", dist, e)
+ return mapping
+
+
+def entry_points_with_dist(group: str) -> Iterator[EPnD]:
+ """Retrieve entry points of the given group.
+
+ This is like the ``entry_points()`` function from ``importlib.metadata``,
+ except it also returns the distribution the entry point was loaded from.
+
+ Note that this may return multiple distributions to the same package if they
+ are loaded from different ``sys.path`` entries. The caller site should
+ implement appropriate deduplication logic if needed.
+
+ :param group: Filter results to only this entrypoint group
+ :return: Generator of (EntryPoint, Distribution) objects for the specified groups
+ """
+ return iter(_get_grouped_entry_points()[group])
diff --git a/tests/conftest.py b/tests/conftest.py
index 509fc9b33d..e38f86571b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -889,3 +889,10 @@ def _clear_db(request):
exc_name_parts.insert(0, exc_module)
extra_msg = "" if request.config.option.db_init else ", try to run with flag --with-db-init"
pytest.exit(f"Unable clear test DB{extra_msg}, got error {'.'.join(exc_name_parts)}: {ex}")
+
+
+@pytest.fixture(autouse=True)
+def _clear_entry_point_cache():
+ from airflow.utils.entry_points import _get_grouped_entry_points
+
+ _get_grouped_entry_points.cache_clear()
diff --git a/tests/utils/test_entry_points.py b/tests/utils/test_entry_points.py
index de4843dbaa..22537245fc 100644
--- a/tests/utils/test_entry_points.py
+++ b/tests/utils/test_entry_points.py
@@ -45,6 +45,6 @@ class MockMetadata:
def test_entry_points_with_dist():
entries = list(entry_points_with_dist("group_x"))
- # The second "dist2" is ignored. Only "group_x" entries are loaded.
- assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2"]
- assert [ep.name for ep, _ in entries] == ["a", "e"]
+ # Only "group_x" entries are loaded. Distributions are not deduplicated.
+ assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2", "dist2"]
+ assert [ep.name for ep, _ in entries] == ["a", "e", "g"]