You are viewing a plain-text version of this content. The canonical link to it was provided as a hyperlink ("here"), which is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/03/07 20:46:17 UTC

[airflow] 04/04: Aggressively cache entry points in process (#29625)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 45caa5ec146760c273fe96b05135b8f2e786bee8
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Sat Feb 25 13:36:14 2023 +0800

    Aggressively cache entry points in process (#29625)
    
    (cherry picked from commit 9f51845fdc305e2f5847584e984278c906f9f293)
---
 airflow/providers_manager.py     |  6 ++++--
 airflow/utils/entry_points.py    | 44 +++++++++++++++++++++++-----------------
 tests/conftest.py                |  7 +++++++
 tests/utils/test_entry_points.py |  6 +++---
 4 files changed, 39 insertions(+), 24 deletions(-)

diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 6088e3b373..a0fe51510e 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -31,6 +31,8 @@ from functools import wraps
 from time import perf_counter
 from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast
 
+from packaging.utils import canonicalize_name
+
 from airflow.exceptions import AirflowOptionalProviderFeatureException
 from airflow.typing_compat import Literal
 from airflow.utils import yaml
@@ -454,8 +456,8 @@ class ProvidersManager(LoggingMixin):
         and verifies only the subset of fields that are needed at runtime.
         """
         for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
-            package_name = dist.metadata["name"]
-            if self._provider_dict.get(package_name) is not None:
+            package_name = canonicalize_name(dist.metadata["name"])
+            if package_name in self._provider_dict:
                 continue
             log.debug("Loading %s from package %s", entry_point, package_name)
             version = dist.version
diff --git a/airflow/utils/entry_points.py b/airflow/utils/entry_points.py
index 41ea38845f..b3f145110f 100644
--- a/airflow/utils/entry_points.py
+++ b/airflow/utils/entry_points.py
@@ -16,10 +16,10 @@
 # under the License.
 from __future__ import annotations
 
+import collections
+import functools
 import logging
-from typing import Iterator
-
-from packaging.utils import canonicalize_name
+from typing import Iterator, Tuple
 
 try:
     import importlib_metadata as metadata
@@ -28,26 +28,32 @@ except ImportError:
 
 log = logging.getLogger(__name__)
 
+EPnD = Tuple[metadata.EntryPoint, metadata.Distribution]
 
-def entry_points_with_dist(group: str) -> Iterator[tuple[metadata.EntryPoint, metadata.Distribution]]:
-    """Retrieve entry points of the given group.
-
-    This is like the ``entry_points()`` function from importlib.metadata,
-    except it also returns the distribution the entry_point was loaded from.
 
-    :param group: Filter results to only this entrypoint group
-    :return: Generator of (EntryPoint, Distribution) objects for the specified groups
-    """
-    loaded: set[str] = set()
+@functools.lru_cache(maxsize=None)
+def _get_grouped_entry_points() -> dict[str, list[EPnD]]:
+    mapping: dict[str, list[EPnD]] = collections.defaultdict(list)
     for dist in metadata.distributions():
         try:
-            key = canonicalize_name(dist.metadata["Name"])
-            if key in loaded:
-                continue
-            loaded.add(key)
             for e in dist.entry_points:
-                if e.group != group:
-                    continue
-                yield e, dist
+                mapping[e.group].append((e, dist))
         except Exception as e:
             log.warning("Error when retrieving package metadata (skipping it): %s, %s", dist, e)
+    return mapping
+
+
+def entry_points_with_dist(group: str) -> Iterator[EPnD]:
+    """Retrieve entry points of the given group.
+
+    This is like the ``entry_points()`` function from ``importlib.metadata``,
+    except it also returns the distribution the entry point was loaded from.
+
+    Note that this may return multiple distributions to the same package if they
+    are loaded from different ``sys.path`` entries. The caller site should
+    implement appropriate deduplication logic if needed.
+
+    :param group: Filter results to only this entrypoint group
+    :return: Generator of (EntryPoint, Distribution) objects for the specified groups
+    """
+    return iter(_get_grouped_entry_points()[group])
diff --git a/tests/conftest.py b/tests/conftest.py
index 945ece6e67..bdaec7da0f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -856,3 +856,10 @@ def reset_logging_config():
 
     logging_config = import_string(settings.LOGGING_CLASS_PATH)
     logging.config.dictConfig(logging_config)
+
+
+@pytest.fixture(autouse=True)
+def _clear_entry_point_cache():
+    from airflow.utils.entry_points import _get_grouped_entry_points
+
+    _get_grouped_entry_points.cache_clear()
diff --git a/tests/utils/test_entry_points.py b/tests/utils/test_entry_points.py
index de4843dbaa..22537245fc 100644
--- a/tests/utils/test_entry_points.py
+++ b/tests/utils/test_entry_points.py
@@ -45,6 +45,6 @@ class MockMetadata:
 def test_entry_points_with_dist():
     entries = list(entry_points_with_dist("group_x"))
 
-    # The second "dist2" is ignored. Only "group_x" entries are loaded.
-    assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2"]
-    assert [ep.name for ep, _ in entries] == ["a", "e"]
+    # Only "group_x" entries are loaded. Distributions are not deduplicated.
+    assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2", "dist2"]
+    assert [ep.name for ep, _ in entries] == ["a", "e", "g"]