You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/08/16 15:35:17 UTC

[airflow] 10/11: Only load distribution of a name once (#25296)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to tag v2.3.3+astro.2
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 966cc8e43a96f367084170cc19f108f8d24bf55e
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Wed Aug 3 21:01:43 2022 +0800

    Only load distribution of a name once (#25296)
    
    (cherry picked from commit c30dc5e64d7229cbf8e9fbe84cfa790dfef5fb8c)
---
 airflow/plugins_manager.py            |  2 +-
 airflow/utils/entry_points.py         | 22 +++++++++++-----
 tests/plugins/test_plugins_manager.py |  3 ++-
 tests/utils/test_entry_points.py      | 49 +++++++++++++++++++++++++++++++++++
 tests/www/views/test_views.py         |  2 +-
 5 files changed, 69 insertions(+), 9 deletions(-)

diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 82e295fa19..431d5fe55a 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -113,7 +113,7 @@ class EntryPointSource(AirflowPluginSource):
     """Class used to define Plugins loaded from entrypoint."""
 
     def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution):
-        self.dist = dist.metadata['name']
+        self.dist = dist.metadata['Name']
         self.version = dist.version
         self.entrypoint = str(entrypoint)
 
diff --git a/airflow/utils/entry_points.py b/airflow/utils/entry_points.py
index 668ed9b994..483f9efe77 100644
--- a/airflow/utils/entry_points.py
+++ b/airflow/utils/entry_points.py
@@ -15,15 +15,20 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import annotations
+
+from typing import Iterator
+
+from packaging.utils import canonicalize_name
+
 try:
-    import importlib_metadata
+    import importlib_metadata as metadata
 except ImportError:
-    from importlib import metadata as importlib_metadata  # type: ignore
+    from importlib import metadata  # type: ignore[no-redef]
 
 
-def entry_points_with_dist(group: str):
-    """
-    Return EntryPoint objects of the given group, along with the distribution information.
+def entry_points_with_dist(group: str) -> Iterator[tuple[metadata.EntryPoint, metadata.Distribution]]:
+    """Retrieve entry points of the given group.
 
     This is like the ``entry_points()`` function from importlib.metadata,
     except it also returns the distribution the entry_point was loaded from.
@@ -31,7 +36,12 @@ def entry_points_with_dist(group: str):
     :param group: Filter results to only this entrypoint group
     :return: Generator of (EntryPoint, Distribution) objects for the specified groups
     """
-    for dist in importlib_metadata.distributions():
+    loaded: set[str] = set()
+    for dist in metadata.distributions():
+        key = canonicalize_name(dist.metadata["Name"])
+        if key in loaded:
+            continue
+        loaded.add(key)
         for e in dist.entry_points:
             if e.group != group:
                 continue
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index f97a811c91..c46b6e83f2 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -298,6 +298,7 @@ class TestPluginsManager:
         from airflow.plugins_manager import import_errors, load_entrypoint_plugins
 
         mock_dist = mock.Mock()
+        mock_dist.metadata = {"Name": "test-dist"}
 
         mock_entrypoint = mock.Mock()
         mock_entrypoint.name = 'test-entrypoint'
@@ -387,7 +388,7 @@ class TestEntryPointSource:
         mock_entrypoint.module = 'module_name_plugin'
 
         mock_dist = mock.Mock()
-        mock_dist.metadata = {'name': 'test-entrypoint-plugin'}
+        mock_dist.metadata = {'Name': 'test-entrypoint-plugin'}
         mock_dist.version = '1.0.0'
         mock_dist.entry_points = [mock_entrypoint]
 
diff --git a/tests/utils/test_entry_points.py b/tests/utils/test_entry_points.py
new file mode 100644
index 0000000000..65f688647e
--- /dev/null
+++ b/tests/utils/test_entry_points.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Iterable
+from unittest import mock
+
+from airflow.utils.entry_points import entry_points_with_dist, metadata
+
+
+class MockDistribution:
+    def __init__(self, name: str, entry_points: Iterable[metadata.EntryPoint]) -> None:
+        self.metadata = {"Name": name}
+        self.entry_points = entry_points
+
+
+class MockMetadata:
+    def distributions(self):
+        return [
+            MockDistribution(
+                "dist1",
+                [metadata.EntryPoint("a", "b", "group_x"), metadata.EntryPoint("c", "d", "group_y")],
+            ),
+            MockDistribution("Dist2", [metadata.EntryPoint("e", "f", "group_x")]),
+            MockDistribution("dist2", [metadata.EntryPoint("g", "h", "group_x")]),  # Duplicated name.
+        ]
+
+
+@mock.patch("airflow.utils.entry_points.metadata", MockMetadata())
+def test_entry_points_with_dist():
+    entries = list(entry_points_with_dist("group_x"))
+
+    # "dist2" canonicalizes to the same name as "Dist2" and is skipped. Only "group_x" entries are loaded.
+    assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2"]
+    assert [ep.name for ep, _ in entries] == ["a", "e"]
diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py
index fa79e145cb..b6d21a3f26 100644
--- a/tests/www/views/test_views.py
+++ b/tests/www/views/test_views.py
@@ -87,7 +87,7 @@ def test_plugin_should_list_entrypoint_on_page_with_details(admin_client):
     mock_plugin = AirflowPlugin()
     mock_plugin.name = "test_plugin"
     mock_plugin.source = EntryPointSource(
-        mock.Mock(), mock.Mock(version='1.0.0', metadata={'name': 'test-entrypoint-testpluginview'})
+        mock.Mock(), mock.Mock(version='1.0.0', metadata={'Name': 'test-entrypoint-testpluginview'})
     )
     with mock_plugin_manager(plugins=[mock_plugin]):
         resp = admin_client.get('/plugin')