You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/03/07 20:46:17 UTC
[airflow] 04/04: Aggressively cache entry points in process (#29625)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 45caa5ec146760c273fe96b05135b8f2e786bee8
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Sat Feb 25 13:36:14 2023 +0800
Aggressively cache entry points in process (#29625)
(cherry picked from commit 9f51845fdc305e2f5847584e984278c906f9f293)
---
airflow/providers_manager.py | 6 ++++--
airflow/utils/entry_points.py | 44 +++++++++++++++++++++++-----------------
tests/conftest.py | 7 +++++++
tests/utils/test_entry_points.py | 6 +++---
4 files changed, 39 insertions(+), 24 deletions(-)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 6088e3b373..a0fe51510e 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -31,6 +31,8 @@ from functools import wraps
from time import perf_counter
from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast
+from packaging.utils import canonicalize_name
+
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.typing_compat import Literal
from airflow.utils import yaml
@@ -454,8 +456,8 @@ class ProvidersManager(LoggingMixin):
and verifies only the subset of fields that are needed at runtime.
"""
for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
- package_name = dist.metadata["name"]
- if self._provider_dict.get(package_name) is not None:
+ package_name = canonicalize_name(dist.metadata["name"])
+ if package_name in self._provider_dict:
continue
log.debug("Loading %s from package %s", entry_point, package_name)
version = dist.version
diff --git a/airflow/utils/entry_points.py b/airflow/utils/entry_points.py
index 41ea38845f..b3f145110f 100644
--- a/airflow/utils/entry_points.py
+++ b/airflow/utils/entry_points.py
@@ -16,10 +16,10 @@
# under the License.
from __future__ import annotations
+import collections
+import functools
import logging
-from typing import Iterator
-
-from packaging.utils import canonicalize_name
+from typing import Iterator, Tuple
try:
import importlib_metadata as metadata
@@ -28,26 +28,32 @@ except ImportError:
log = logging.getLogger(__name__)
+EPnD = Tuple[metadata.EntryPoint, metadata.Distribution]
-def entry_points_with_dist(group: str) -> Iterator[tuple[metadata.EntryPoint, metadata.Distribution]]:
- """Retrieve entry points of the given group.
-
- This is like the ``entry_points()`` function from importlib.metadata,
- except it also returns the distribution the entry_point was loaded from.
- :param group: Filter results to only this entrypoint group
- :return: Generator of (EntryPoint, Distribution) objects for the specified groups
- """
- loaded: set[str] = set()
+@functools.lru_cache(maxsize=None)
+def _get_grouped_entry_points() -> dict[str, list[EPnD]]:
+ mapping: dict[str, list[EPnD]] = collections.defaultdict(list)
for dist in metadata.distributions():
try:
- key = canonicalize_name(dist.metadata["Name"])
- if key in loaded:
- continue
- loaded.add(key)
for e in dist.entry_points:
- if e.group != group:
- continue
- yield e, dist
+ mapping[e.group].append((e, dist))
except Exception as e:
log.warning("Error when retrieving package metadata (skipping it): %s, %s", dist, e)
+ return mapping
+
+
+def entry_points_with_dist(group: str) -> Iterator[EPnD]:
+ """Retrieve entry points of the given group.
+
+ This is like the ``entry_points()`` function from ``importlib.metadata``,
+ except it also returns the distribution the entry point was loaded from.
+
+ Note that this may return multiple distributions to the same package if they
+ are loaded from different ``sys.path`` entries. The caller site should
+ implement appropriate deduplication logic if needed.
+
+ :param group: Filter results to only this entrypoint group
+ :return: Generator of (EntryPoint, Distribution) objects for the specified groups
+ """
+ return iter(_get_grouped_entry_points()[group])
diff --git a/tests/conftest.py b/tests/conftest.py
index 945ece6e67..bdaec7da0f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -856,3 +856,10 @@ def reset_logging_config():
logging_config = import_string(settings.LOGGING_CLASS_PATH)
logging.config.dictConfig(logging_config)
+
+
+@pytest.fixture(autouse=True)
+def _clear_entry_point_cache():
+ from airflow.utils.entry_points import _get_grouped_entry_points
+
+ _get_grouped_entry_points.cache_clear()
diff --git a/tests/utils/test_entry_points.py b/tests/utils/test_entry_points.py
index de4843dbaa..22537245fc 100644
--- a/tests/utils/test_entry_points.py
+++ b/tests/utils/test_entry_points.py
@@ -45,6 +45,6 @@ class MockMetadata:
def test_entry_points_with_dist():
entries = list(entry_points_with_dist("group_x"))
- # The second "dist2" is ignored. Only "group_x" entries are loaded.
- assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2"]
- assert [ep.name for ep, _ in entries] == ["a", "e"]
+ # Only "group_x" entries are loaded. Distributions are not deduplicated.
+ assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2", "dist2"]
+ assert [ep.name for ep, _ in entries] == ["a", "e", "g"]