You are viewing a plain-text version of this content. The canonical link to it (the "here" hyperlink in the original page) is not preserved in this version.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/08/16 15:35:17 UTC
[airflow] 10/11: Only load distribution of a name once (#25296)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to tag v2.3.3+astro.2
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 966cc8e43a96f367084170cc19f108f8d24bf55e
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Wed Aug 3 21:01:43 2022 +0800
Only load distribution of a name once (#25296)
(cherry picked from commit c30dc5e64d7229cbf8e9fbe84cfa790dfef5fb8c)
---
airflow/plugins_manager.py | 2 +-
airflow/utils/entry_points.py | 22 +++++++++++-----
tests/plugins/test_plugins_manager.py | 3 ++-
tests/utils/test_entry_points.py | 49 +++++++++++++++++++++++++++++++++++
tests/www/views/test_views.py | 2 +-
5 files changed, 69 insertions(+), 9 deletions(-)
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 82e295fa19..431d5fe55a 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -113,7 +113,7 @@ class EntryPointSource(AirflowPluginSource):
"""Class used to define Plugins loaded from entrypoint."""
def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution):
- self.dist = dist.metadata['name']
+ self.dist = dist.metadata['Name']
self.version = dist.version
self.entrypoint = str(entrypoint)
diff --git a/airflow/utils/entry_points.py b/airflow/utils/entry_points.py
index 668ed9b994..483f9efe77 100644
--- a/airflow/utils/entry_points.py
+++ b/airflow/utils/entry_points.py
@@ -15,15 +15,20 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
+from typing import Iterator
+
+from packaging.utils import canonicalize_name
+
try:
- import importlib_metadata
+ import importlib_metadata as metadata
except ImportError:
- from importlib import metadata as importlib_metadata # type: ignore
+ from importlib import metadata # type: ignore[no-redef]
-def entry_points_with_dist(group: str):
- """
- Return EntryPoint objects of the given group, along with the distribution information.
+def entry_points_with_dist(group: str) -> Iterator[tuple[metadata.EntryPoint, metadata.Distribution]]:
+ """Retrieve entry points of the given group.
This is like the ``entry_points()`` function from importlib.metadata,
except it also returns the distribution the entry_point was loaded from.
@@ -31,7 +36,12 @@ def entry_points_with_dist(group: str):
:param group: Filter results to only this entrypoint group
:return: Generator of (EntryPoint, Distribution) objects for the specified groups
"""
- for dist in importlib_metadata.distributions():
+ loaded: set[str] = set()
+ for dist in metadata.distributions():
+ key = canonicalize_name(dist.metadata["Name"])
+ if key in loaded:
+ continue
+ loaded.add(key)
for e in dist.entry_points:
if e.group != group:
continue
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index f97a811c91..c46b6e83f2 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -298,6 +298,7 @@ class TestPluginsManager:
from airflow.plugins_manager import import_errors, load_entrypoint_plugins
mock_dist = mock.Mock()
+ mock_dist.metadata = {"Name": "test-dist"}
mock_entrypoint = mock.Mock()
mock_entrypoint.name = 'test-entrypoint'
@@ -387,7 +388,7 @@ class TestEntryPointSource:
mock_entrypoint.module = 'module_name_plugin'
mock_dist = mock.Mock()
- mock_dist.metadata = {'name': 'test-entrypoint-plugin'}
+ mock_dist.metadata = {'Name': 'test-entrypoint-plugin'}
mock_dist.version = '1.0.0'
mock_dist.entry_points = [mock_entrypoint]
diff --git a/tests/utils/test_entry_points.py b/tests/utils/test_entry_points.py
new file mode 100644
index 0000000000..65f688647e
--- /dev/null
+++ b/tests/utils/test_entry_points.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Iterable
+from unittest import mock
+
+from airflow.utils.entry_points import entry_points_with_dist, metadata
+
+
+ class MockDistribution:  # Minimal stand-in for metadata.Distribution: provides only .metadata["Name"] and .entry_points.
+ def __init__(self, name: str, entry_points: Iterable[metadata.EntryPoint]) -> None:
+ self.metadata = {"Name": name}  # Plain dict, so the key must be spelled "Name" exactly as the code under test reads it.
+ self.entry_points = entry_points
+
+
+ class MockMetadata:  # Stub for the patched `metadata` module; only distributions() is provided.
+ def distributions(self):  # Order matters: the first distribution seen for a canonical name wins.
+ return [
+ MockDistribution(
+ "dist1",
+ [metadata.EntryPoint("a", "b", "group_x"), metadata.EntryPoint("c", "d", "group_y")],
+ ),
+ MockDistribution("Dist2", [metadata.EntryPoint("e", "f", "group_x")]),
+ MockDistribution("dist2", [metadata.EntryPoint("g", "h", "group_x")]), # Same canonical name as "Dist2"; expected to be skipped.
+ ]
+
+
+ @mock.patch("airflow.utils.entry_points.metadata", MockMetadata())  # Swap the module-level `metadata` import for the stub.
+ def test_entry_points_with_dist():
+ entries = list(entry_points_with_dist("group_x"))
+
+ # "dist2" canonicalizes to the same key as the earlier "Dist2", so it is skipped; only "group_x" entries are returned.
+ assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2"]
+ assert [ep.name for ep, _ in entries] == ["a", "e"]
diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py
index fa79e145cb..b6d21a3f26 100644
--- a/tests/www/views/test_views.py
+++ b/tests/www/views/test_views.py
@@ -87,7 +87,7 @@ def test_plugin_should_list_entrypoint_on_page_with_details(admin_client):
mock_plugin = AirflowPlugin()
mock_plugin.name = "test_plugin"
mock_plugin.source = EntryPointSource(
- mock.Mock(), mock.Mock(version='1.0.0', metadata={'name': 'test-entrypoint-testpluginview'})
+ mock.Mock(), mock.Mock(version='1.0.0', metadata={'Name': 'test-entrypoint-testpluginview'})
)
with mock_plugin_manager(plugins=[mock_plugin]):
resp = admin_client.get('/plugin')