You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2020/11/29 06:21:24 UTC

[airflow] branch master updated: Replace pkg_resources with importlib.metadata to avoid VersionConflict errors (#12694)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ef9aa7  Replace pkg_resources with importlib.metadata to avoid VersionConflict errors (#12694)
7ef9aa7 is described below

commit 7ef9aa7d545f11442b6ebb86590cd8ce5f98430b
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Sun Nov 29 06:19:47 2020 +0000

    Replace pkg_resources with importlib.metadata to avoid VersionConflict errors (#12694)
    
    Using `pkg_resources.iter_entry_points` validates the version
    constraints, and if any fail it will throw an Exception for that
    entrypoint.
    
    This sounds nice, but is a huge mis-feature: a single version conflict
    is enough to stop that plugin or provider from being loaded at all.
    
    So instead of that, switch to using importlib.metadata (well, its
    backport, importlib_metadata), which just gives us the entry points - no
    other verification of requirements is performed.
    
    This has two advantages:
    
    1. providers and plugins load much more reliably.
    2. it's faster too
    
    Closes #12692
---
 airflow/plugins_manager.py            | 46 +++++++++++++++-------
 airflow/providers_manager.py          | 19 +++------
 setup.cfg                             |  1 +
 tests/plugins/test_plugins_manager.py | 74 +++++++++++++++++++----------------
 tests/www/test_views.py               | 46 +++++-----------------
 5 files changed, 89 insertions(+), 97 deletions(-)

diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index fe94e21..dadae6a 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -27,7 +27,7 @@ import time
 import types
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
-import pkg_resources
+import importlib_metadata
 
 from airflow import settings
 from airflow.utils.file import find_path_from_directory
@@ -88,15 +88,16 @@ class PluginsDirectorySource(AirflowPluginSource):
 class EntryPointSource(AirflowPluginSource):
     """Class used to define Plugins loaded from entrypoint."""
 
-    def __init__(self, entrypoint):
-        self.dist = str(entrypoint.dist)
+    def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution):
+        self.dist = dist.metadata['name']
+        self.version = dist.version
         self.entrypoint = str(entrypoint)
 
     def __str__(self):
-        return f"{self.dist}: {self.entrypoint}"
+        return f"{self.dist}=={self.version}: {self.entrypoint}"
 
     def __html__(self):
-        return f"<em>{self.dist}:</em> {self.entrypoint}"
+        return f"<em>{self.dist}=={self.version}:</em> {self.entrypoint}"
 
 
 class AirflowPluginException(Exception):
@@ -169,6 +170,23 @@ def is_valid_plugin(plugin_obj):
     return False
 
 
+def entry_points_with_dist(group: str):
+    """
+    Return EntryPoint objects of the given group, along with the distribution information.
+
+    This is like the ``entry_points()`` function from importlib.metadata,
+    except it also returns the distribution the entry_point was loaded from.
+
+    :param group: Filter results to only this entrypoint group
+    :return: Generator of (EntryPoint, Distribution) tuples for the specified group
+    """
+    for dist in importlib_metadata.distributions():
+        for e in dist.entry_points:
+            if e.group != group:
+                continue
+            yield (e, dist)
+
+
 def load_entrypoint_plugins():
     """
     Load and register plugins AirflowPlugin subclasses from the entrypoints.
@@ -177,20 +195,20 @@ def load_entrypoint_plugins():
     global import_errors  # pylint: disable=global-statement
     global plugins  # pylint: disable=global-statement
 
-    entry_points = pkg_resources.iter_entry_points('airflow.plugins')
-
     log.debug("Loading plugins from entrypoints")
 
-    for entry_point in entry_points:  # pylint: disable=too-many-nested-blocks
+    for entry_point, dist in entry_points_with_dist('airflow.plugins'):
         log.debug('Importing entry_point plugin %s', entry_point.name)
         try:
             plugin_class = entry_point.load()
-            if is_valid_plugin(plugin_class):
-                plugin_instance = plugin_class()
-                if callable(getattr(plugin_instance, 'on_load', None)):
-                    plugin_instance.on_load()
-                    plugin_instance.source = EntryPointSource(entry_point)
-                    plugins.append(plugin_instance)
+            if not is_valid_plugin(plugin_class):
+                continue
+
+            plugin_instance = plugin_class()
+            if callable(getattr(plugin_instance, 'on_load', None)):
+                plugin_instance.on_load()
+                plugin_instance.source = EntryPointSource(entry_point, dist)
+                plugins.append(plugin_instance)
         except Exception as e:  # pylint: disable=broad-except
             log.exception("Failed to import plugin %s", entry_point.name)
-            import_errors[entry_point.module_name] = str(e)
+            import_errors[entry_point.module] = str(e)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 30041c9..44821f7 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -24,7 +24,6 @@ from collections import OrderedDict
 from typing import Dict, Tuple
 
 import jsonschema
-import pkg_resources
 import yaml
 
 try:
@@ -90,19 +89,13 @@ class ProvidersManager:
         via the 'apache_airflow_provider' entrypoint as a dictionary conforming to the
         'airflow/provider.yaml.schema.json' schema.
         """
-        for entry_point in pkg_resources.iter_entry_points('apache_airflow_provider'):
-            package_name = entry_point.dist.project_name
+        from airflow.plugins_manager import entry_points_with_dist
+
+        for (entry_point, dist) in entry_points_with_dist('apache_airflow_provider'):
+            package_name = dist.metadata['name']
             log.debug("Loading %s from package %s", entry_point, package_name)
-            version = entry_point.dist.version
-            try:
-                provider_info = entry_point.load()()
-            except pkg_resources.VersionConflict as e:
-                log.warning(
-                    "The provider package %s could not be registered because of version conflict : %s",
-                    package_name,
-                    e,
-                )
-                continue
+            version = dist.version
+            provider_info = entry_point.load()()
             self._validator.validate(provider_info)
             provider_info_package_name = provider_info['package-name']
             if package_name != provider_info_package_name:
diff --git a/setup.cfg b/setup.cfg
index 38e05ef..282ceff 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -100,6 +100,7 @@ install_requires =
     funcsigs>=1.0.0, <2.0.0
     graphviz>=0.12
     gunicorn>=19.5.0, <20.0
+    importlib_metadata~=1.7 # We could work with 3.1, but argcomplete needs <2
     importlib_resources~=1.4
     iso8601>=0.1.12
     itsdangerous>=1.1.0
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index 48407ae..117df98 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -78,30 +78,6 @@ class TestPluginsRBAC(unittest.TestCase):
         self.assertTrue('test_plugin' in self.app.blueprints)
         self.assertEqual(self.app.blueprints['test_plugin'].name, bp.name)
 
-    @mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points')
-    def test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_ep_plugins):
-        """
-        Test that Airflow does not raise an Error if there is any Exception because of the
-        Plugin.
-        """
-        from airflow.plugins_manager import import_errors, load_entrypoint_plugins
-
-        mock_entrypoint = mock.Mock()
-        mock_entrypoint.name = 'test-entrypoint'
-        mock_entrypoint.module_name = 'test.plugins.test_plugins_manager'
-        mock_entrypoint.load.side_effect = Exception('Version Conflict')
-        mock_ep_plugins.return_value = [mock_entrypoint]
-
-        with self.assertLogs("airflow.plugins_manager", level="ERROR") as log_output:
-            load_entrypoint_plugins()
-
-            received_logs = log_output.output[0]
-            # Assert Traceback is shown too
-            assert "Traceback (most recent call last):" in received_logs
-            assert "Version Conflict" in received_logs
-            assert "Failed to import plugin test-entrypoint" in received_logs
-            assert ("test.plugins.test_plugins_manager", "Version Conflict") in import_errors.items()
-
 
 class TestPluginsManager:
     def test_no_log_when_no_plugins(self, caplog):
@@ -210,6 +186,33 @@ class TestPluginsManager:
 
         assert caplog.record_tuples == []
 
+    def test_entrypoint_plugin_errors_dont_raise_exceptions(self, caplog):
+        """
+        Test that Airflow does not raise an error if there is any Exception because of a plugin.
+        """
+        from airflow.plugins_manager import import_errors, load_entrypoint_plugins
+
+        mock_dist = mock.Mock()
+
+        mock_entrypoint = mock.Mock()
+        mock_entrypoint.name = 'test-entrypoint'
+        mock_entrypoint.group = 'airflow.plugins'
+        mock_entrypoint.module = 'test.plugins.test_plugins_manager'
+        mock_entrypoint.load.side_effect = ImportError('my_fake_module not found')
+        mock_dist.entry_points = [mock_entrypoint]
+
+        with mock.patch('importlib_metadata.distributions', return_value=[mock_dist]), caplog.at_level(
+            logging.ERROR, logger='airflow.plugins_manager'
+        ):
+            load_entrypoint_plugins()
+
+            received_logs = caplog.text
+            # Assert Traceback is shown too
+            assert "Traceback (most recent call last):" in received_logs
+            assert "my_fake_module not found" in received_logs
+            assert "Failed to import plugin test-entrypoint" in received_logs
+            assert ("test.plugins.test_plugins_manager", "my_fake_module not found") in import_errors.items()
+
+
 
 class TestPluginsDirectorySource(unittest.TestCase):
     def test_should_return_correct_path_name(self):
@@ -221,20 +224,23 @@ class TestPluginsDirectorySource(unittest.TestCase):
         self.assertEqual("<em>$PLUGINS_FOLDER/</em>test_plugins_manager.py", source.__html__())
 
 
-class TestEntryPointSource(unittest.TestCase):
-    @mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points')
-    def test_should_return_correct_source_details(self, mock_ep_plugins):
+class TestEntryPointSource:
+    def test_should_return_correct_source_details(self):
         from airflow import plugins_manager
 
         mock_entrypoint = mock.Mock()
         mock_entrypoint.name = 'test-entrypoint-plugin'
         mock_entrypoint.module_name = 'module_name_plugin'
-        mock_entrypoint.dist = 'test-entrypoint-plugin==1.0.0'
-        mock_ep_plugins.return_value = [mock_entrypoint]
 
-        plugins_manager.load_entrypoint_plugins()
+        mock_dist = mock.Mock()
+        mock_dist.metadata = {'name': 'test-entrypoint-plugin'}
+        mock_dist.version = '1.0.0'
+        mock_dist.entry_points = [mock_entrypoint]
+
+        with mock.patch('importlib_metadata.distributions', return_value=[mock_dist]):
+            plugins_manager.load_entrypoint_plugins()
 
-        source = plugins_manager.EntryPointSource(mock_entrypoint)
-        self.assertEqual(str(mock_entrypoint), source.entrypoint)
-        self.assertEqual("test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint), str(source))
-        self.assertEqual("<em>test-entrypoint-plugin==1.0.0:</em> " + str(mock_entrypoint), source.__html__())
+        source = plugins_manager.EntryPointSource(mock_entrypoint, mock_dist)
+        assert str(mock_entrypoint) == source.entrypoint
+        assert "test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint) == str(source)
+        assert "<em>test-entrypoint-plugin==1.0.0:</em> " + str(mock_entrypoint) == source.__html__()
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index 9208f32..d4d0572 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -52,7 +52,7 @@ from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.bash import BashOperator
 from airflow.operators.dummy_operator import DummyOperator
-from airflow.plugins_manager import AirflowPlugin, EntryPointSource, PluginsDirectorySource
+from airflow.plugins_manager import AirflowPlugin, EntryPointSource
 from airflow.security import permissions
 from airflow.ti_deps.dependencies_states import QUEUEABLE_STATES, RUNNABLE_STATES
 from airflow.utils import dates, timezone
@@ -67,6 +67,7 @@ from tests.test_utils import fab_utils
 from tests.test_utils.asserts import assert_queries_count
 from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_runs
+from tests.test_utils.mock_plugins import mock_plugin_manager
 
 
 class TemplateWithContext(NamedTuple):
@@ -337,10 +338,6 @@ class PluginOperator(BaseOperator):
     pass
 
 
-class EntrypointPlugin(AirflowPlugin):
-    name = 'test-entrypoint-testpluginview'
-
-
 class TestPluginView(TestBase):
     def test_should_list_plugins_on_page_with_details(self):
         resp = self.client.get('/plugin')
@@ -349,19 +346,15 @@ class TestPluginView(TestBase):
         self.check_content_in_response("source", resp)
         self.check_content_in_response("<em>$PLUGINS_FOLDER/</em>test_plugin.py", resp)
 
-    @mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points')
-    def test_should_list_entrypoint_plugins_on_page_with_details(self, mock_ep_plugins):
-        from airflow.plugins_manager import load_entrypoint_plugins
-
-        mock_entrypoint = mock.Mock()
-        mock_entrypoint.name = 'test-entrypoint-testpluginview'
-        mock_entrypoint.module_name = 'module_name_testpluginview'
-        mock_entrypoint.dist = 'test-entrypoint-testpluginview==1.0.0'
-        mock_entrypoint.load.return_value = EntrypointPlugin
-        mock_ep_plugins.return_value = [mock_entrypoint]
+    def test_should_list_entrypoint_plugins_on_page_with_details(self):
 
-        load_entrypoint_plugins()
-        resp = self.client.get('/plugin')
+        mock_plugin = AirflowPlugin()
+        mock_plugin.name = "test_plugin"
+        mock_plugin.source = EntryPointSource(
+            mock.Mock(), mock.Mock(version='1.0.0', metadata={'name': 'test-entrypoint-testpluginview'})
+        )
+        with mock_plugin_manager(plugins=[mock_plugin]):
+            resp = self.client.get('/plugin')
 
         self.check_content_in_response("test_plugin", resp)
         self.check_content_in_response("Airflow Plugins", resp)
@@ -369,25 +362,6 @@ class TestPluginView(TestBase):
         self.check_content_in_response("<em>test-entrypoint-testpluginview==1.0.0:</em> <Mock id=", resp)
 
 
-class TestPluginsDirectorySource(unittest.TestCase):
-    def test_should_provide_correct_attribute_values(self):
-        source = PluginsDirectorySource("./test_views.py")
-        self.assertEqual("$PLUGINS_FOLDER/../../test_views.py", str(source))
-        self.assertEqual("<em>$PLUGINS_FOLDER/</em>../../test_views.py", source.__html__())
-        self.assertEqual("../../test_views.py", source.path)
-
-
-class TestEntryPointSource(unittest.TestCase):
-    def test_should_provide_correct_attribute_values(self):
-        mock_entrypoint = mock.Mock()
-        mock_entrypoint.dist = 'test-entrypoint-dist==1.0.0'
-        source = EntryPointSource(mock_entrypoint)
-        self.assertEqual("test-entrypoint-dist==1.0.0", source.dist)
-        self.assertEqual(str(mock_entrypoint), source.entrypoint)
-        self.assertEqual("test-entrypoint-dist==1.0.0: " + str(mock_entrypoint), str(source))
-        self.assertEqual("<em>test-entrypoint-dist==1.0.0:</em> " + str(mock_entrypoint), source.__html__())
-
-
 class TestPoolModelView(TestBase):
     def setUp(self):
         super().setUp()