You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2020/11/29 06:21:24 UTC
[airflow] branch master updated: Replace pkg_resources with
importlib.metadata to avoid VersionConflict errors (#12694)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 7ef9aa7 Replace pkg_resources with importlib.metadata to avoid VersionConflict errors (#12694)
7ef9aa7 is described below
commit 7ef9aa7d545f11442b6ebb86590cd8ce5f98430b
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Sun Nov 29 06:19:47 2020 +0000
Replace pkg_resources with importlib.metadata to avoid VersionConflict errors (#12694)
Using `pkg_resources.iter_entry_points` validates the version
constraints, and if any fail it will throw an Exception for that
entrypoint.
This sounds nice, but is a huge mis-feature.
So instead of that, switch to using importlib.metadata (well, its
backport importlib_metadata) that just gives us the entrypoints - no
other verification of requirements is performed.
This has two advantages:
1. providers and plugins load much more reliably.
2. it's faster too
Closes #12692
---
airflow/plugins_manager.py | 46 +++++++++++++++-------
airflow/providers_manager.py | 19 +++------
setup.cfg | 1 +
tests/plugins/test_plugins_manager.py | 74 +++++++++++++++++++----------------
tests/www/test_views.py | 46 +++++-----------------
5 files changed, 89 insertions(+), 97 deletions(-)
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index fe94e21..dadae6a 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -27,7 +27,7 @@ import time
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
-import pkg_resources
+import importlib_metadata
from airflow import settings
from airflow.utils.file import find_path_from_directory
@@ -88,15 +88,16 @@ class PluginsDirectorySource(AirflowPluginSource):
class EntryPointSource(AirflowPluginSource):
"""Class used to define Plugins loaded from entrypoint."""
- def __init__(self, entrypoint):
- self.dist = str(entrypoint.dist)
+ def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution):
+ self.dist = dist.metadata['name']
+ self.version = dist.version
self.entrypoint = str(entrypoint)
def __str__(self):
- return f"{self.dist}: {self.entrypoint}"
+ return f"{self.dist}=={self.version}: {self.entrypoint}"
def __html__(self):
- return f"<em>{self.dist}:</em> {self.entrypoint}"
+ return f"<em>{self.dist}=={self.version}:</em> {self.entrypoint}"
class AirflowPluginException(Exception):
@@ -169,6 +170,23 @@ def is_valid_plugin(plugin_obj):
return False
+def entry_points_with_dist(group: str):
+ """
+ Return EntryPoint objects of the given group, along with the distribution information.
+
+ This is like the ``entry_points()`` function from importlib.metadata,
+ except it also returns the distribution the entry_point was loaded from.
+
+ :param group: Filter results to only this entrypoint group
+ :return: Generator of (EntryPoint, Distribution) objects for the specified group
+ """
+ for dist in importlib_metadata.distributions():
+ for e in dist.entry_points:
+ if e.group != group:
+ continue
+ yield (e, dist)
+
+
def load_entrypoint_plugins():
"""
Load and register plugins AirflowPlugin subclasses from the entrypoints.
@@ -177,20 +195,20 @@ def load_entrypoint_plugins():
global import_errors # pylint: disable=global-statement
global plugins # pylint: disable=global-statement
- entry_points = pkg_resources.iter_entry_points('airflow.plugins')
-
log.debug("Loading plugins from entrypoints")
- for entry_point in entry_points: # pylint: disable=too-many-nested-blocks
+ for entry_point, dist in entry_points_with_dist('airflow.plugins'):
log.debug('Importing entry_point plugin %s', entry_point.name)
try:
plugin_class = entry_point.load()
- if is_valid_plugin(plugin_class):
- plugin_instance = plugin_class()
- if callable(getattr(plugin_instance, 'on_load', None)):
- plugin_instance.on_load()
- plugin_instance.source = EntryPointSource(entry_point)
- plugins.append(plugin_instance)
+ if not is_valid_plugin(plugin_class):
+ continue
+
+ plugin_instance = plugin_class()
+ if callable(getattr(plugin_instance, 'on_load', None)):
+ plugin_instance.on_load()
+ plugin_instance.source = EntryPointSource(entry_point, dist)
+ plugins.append(plugin_instance)
except Exception as e: # pylint: disable=broad-except
log.exception("Failed to import plugin %s", entry_point.name)
import_errors[entry_point.module_name] = str(e)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 30041c9..44821f7 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -24,7 +24,6 @@ from collections import OrderedDict
from typing import Dict, Tuple
import jsonschema
-import pkg_resources
import yaml
try:
@@ -90,19 +89,13 @@ class ProvidersManager:
via the 'apache_airflow_provider' entrypoint as a dictionary conforming to the
'airflow/provider.yaml.schema.json' schema.
"""
- for entry_point in pkg_resources.iter_entry_points('apache_airflow_provider'):
- package_name = entry_point.dist.project_name
+ from airflow.plugins_manager import entry_points_with_dist
+
+ for (entry_point, dist) in entry_points_with_dist('apache_airflow_provider'):
+ package_name = dist.metadata['name']
log.debug("Loading %s from package %s", entry_point, package_name)
- version = entry_point.dist.version
- try:
- provider_info = entry_point.load()()
- except pkg_resources.VersionConflict as e:
- log.warning(
- "The provider package %s could not be registered because of version conflict : %s",
- package_name,
- e,
- )
- continue
+ version = dist.version
+ provider_info = entry_point.load()()
self._validator.validate(provider_info)
provider_info_package_name = provider_info['package-name']
if package_name != provider_info_package_name:
diff --git a/setup.cfg b/setup.cfg
index 38e05ef..282ceff 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -100,6 +100,7 @@ install_requires =
funcsigs>=1.0.0, <2.0.0
graphviz>=0.12
gunicorn>=19.5.0, <20.0
+ importlib_metadata~=1.7 # We could work with 3.1, but argparse needs <2
importlib_resources~=1.4
iso8601>=0.1.12
itsdangerous>=1.1.0
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index 48407ae..117df98 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -78,30 +78,6 @@ class TestPluginsRBAC(unittest.TestCase):
self.assertTrue('test_plugin' in self.app.blueprints)
self.assertEqual(self.app.blueprints['test_plugin'].name, bp.name)
- @mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points')
- def test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_ep_plugins):
- """
- Test that Airflow does not raise an Error if there is any Exception because of the
- Plugin.
- """
- from airflow.plugins_manager import import_errors, load_entrypoint_plugins
-
- mock_entrypoint = mock.Mock()
- mock_entrypoint.name = 'test-entrypoint'
- mock_entrypoint.module_name = 'test.plugins.test_plugins_manager'
- mock_entrypoint.load.side_effect = Exception('Version Conflict')
- mock_ep_plugins.return_value = [mock_entrypoint]
-
- with self.assertLogs("airflow.plugins_manager", level="ERROR") as log_output:
- load_entrypoint_plugins()
-
- received_logs = log_output.output[0]
- # Assert Traceback is shown too
- assert "Traceback (most recent call last):" in received_logs
- assert "Version Conflict" in received_logs
- assert "Failed to import plugin test-entrypoint" in received_logs
- assert ("test.plugins.test_plugins_manager", "Version Conflict") in import_errors.items()
-
class TestPluginsManager:
def test_no_log_when_no_plugins(self, caplog):
@@ -210,6 +186,33 @@ class TestPluginsManager:
assert caplog.record_tuples == []
+ def test_entrypoint_plugin_errors_dont_raise_exceptions(self, caplog):
+ """
+ Test that Airflow does not raise an error if there is any Exception because of a plugin.
+ """
+ from airflow.plugins_manager import import_errors, load_entrypoint_plugins
+
+ mock_dist = mock.Mock()
+
+ mock_entrypoint = mock.Mock()
+ mock_entrypoint.name = 'test-entrypoint'
+ mock_entrypoint.group = 'airflow.plugins'
+ mock_entrypoint.module_name = 'test.plugins.test_plugins_manager'
+ mock_entrypoint.load.side_effect = ImportError('my_fake_module not found')
+ mock_dist.entry_points = [mock_entrypoint]
+
+ with mock.patch('importlib_metadata.distributions', return_value=[mock_dist]), caplog.at_level(
+ logging.ERROR, logger='airflow.plugins_manager'
+ ):
+ load_entrypoint_plugins()
+
+ received_logs = caplog.text
+ # Assert Traceback is shown too
+ assert "Traceback (most recent call last):" in received_logs
+ assert "my_fake_module not found" in received_logs
+ assert "Failed to import plugin test-entrypoint" in received_logs
+ assert ("test.plugins.test_plugins_manager", "my_fake_module not found") in import_errors.items()
+
class TestPluginsDirectorySource(unittest.TestCase):
def test_should_return_correct_path_name(self):
@@ -221,20 +224,23 @@ class TestPluginsDirectorySource(unittest.TestCase):
self.assertEqual("<em>$PLUGINS_FOLDER/</em>test_plugins_manager.py", source.__html__())
-class TestEntryPointSource(unittest.TestCase):
- @mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points')
- def test_should_return_correct_source_details(self, mock_ep_plugins):
+class TestEntryPointSource:
+ def test_should_return_correct_source_details(self):
from airflow import plugins_manager
mock_entrypoint = mock.Mock()
mock_entrypoint.name = 'test-entrypoint-plugin'
mock_entrypoint.module_name = 'module_name_plugin'
- mock_entrypoint.dist = 'test-entrypoint-plugin==1.0.0'
- mock_ep_plugins.return_value = [mock_entrypoint]
- plugins_manager.load_entrypoint_plugins()
+ mock_dist = mock.Mock()
+ mock_dist.metadata = {'name': 'test-entrypoint-plugin'}
+ mock_dist.version = '1.0.0'
+ mock_dist.entry_points = [mock_entrypoint]
+
+ with mock.patch('importlib_metadata.distributions', return_value=[mock_dist]):
+ plugins_manager.load_entrypoint_plugins()
- source = plugins_manager.EntryPointSource(mock_entrypoint)
- self.assertEqual(str(mock_entrypoint), source.entrypoint)
- self.assertEqual("test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint), str(source))
- self.assertEqual("<em>test-entrypoint-plugin==1.0.0:</em> " + str(mock_entrypoint), source.__html__())
+ source = plugins_manager.EntryPointSource(mock_entrypoint, mock_dist)
+ assert str(mock_entrypoint) == source.entrypoint
+ assert "test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint) == str(source)
+ assert "<em>test-entrypoint-plugin==1.0.0:</em> " + str(mock_entrypoint) == source.__html__()
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index 9208f32..d4d0572 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -52,7 +52,7 @@ from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.bash import BashOperator
from airflow.operators.dummy_operator import DummyOperator
-from airflow.plugins_manager import AirflowPlugin, EntryPointSource, PluginsDirectorySource
+from airflow.plugins_manager import AirflowPlugin, EntryPointSource
from airflow.security import permissions
from airflow.ti_deps.dependencies_states import QUEUEABLE_STATES, RUNNABLE_STATES
from airflow.utils import dates, timezone
@@ -67,6 +67,7 @@ from tests.test_utils import fab_utils
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_runs
+from tests.test_utils.mock_plugins import mock_plugin_manager
class TemplateWithContext(NamedTuple):
@@ -337,10 +338,6 @@ class PluginOperator(BaseOperator):
pass
-class EntrypointPlugin(AirflowPlugin):
- name = 'test-entrypoint-testpluginview'
-
-
class TestPluginView(TestBase):
def test_should_list_plugins_on_page_with_details(self):
resp = self.client.get('/plugin')
@@ -349,19 +346,15 @@ class TestPluginView(TestBase):
self.check_content_in_response("source", resp)
self.check_content_in_response("<em>$PLUGINS_FOLDER/</em>test_plugin.py", resp)
- @mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points')
- def test_should_list_entrypoint_plugins_on_page_with_details(self, mock_ep_plugins):
- from airflow.plugins_manager import load_entrypoint_plugins
-
- mock_entrypoint = mock.Mock()
- mock_entrypoint.name = 'test-entrypoint-testpluginview'
- mock_entrypoint.module_name = 'module_name_testpluginview'
- mock_entrypoint.dist = 'test-entrypoint-testpluginview==1.0.0'
- mock_entrypoint.load.return_value = EntrypointPlugin
- mock_ep_plugins.return_value = [mock_entrypoint]
+ def test_should_list_entrypoint_plugins_on_page_with_details(self):
- load_entrypoint_plugins()
- resp = self.client.get('/plugin')
+ mock_plugin = AirflowPlugin()
+ mock_plugin.name = "test_plugin"
+ mock_plugin.source = EntryPointSource(
+ mock.Mock(), mock.Mock(version='1.0.0', metadata={'name': 'test-entrypoint-testpluginview'})
+ )
+ with mock_plugin_manager(plugins=[mock_plugin]):
+ resp = self.client.get('/plugin')
self.check_content_in_response("test_plugin", resp)
self.check_content_in_response("Airflow Plugins", resp)
@@ -369,25 +362,6 @@ class TestPluginView(TestBase):
self.check_content_in_response("<em>test-entrypoint-testpluginview==1.0.0:</em> <Mock id=", resp)
-class TestPluginsDirectorySource(unittest.TestCase):
- def test_should_provide_correct_attribute_values(self):
- source = PluginsDirectorySource("./test_views.py")
- self.assertEqual("$PLUGINS_FOLDER/../../test_views.py", str(source))
- self.assertEqual("<em>$PLUGINS_FOLDER/</em>../../test_views.py", source.__html__())
- self.assertEqual("../../test_views.py", source.path)
-
-
-class TestEntryPointSource(unittest.TestCase):
- def test_should_provide_correct_attribute_values(self):
- mock_entrypoint = mock.Mock()
- mock_entrypoint.dist = 'test-entrypoint-dist==1.0.0'
- source = EntryPointSource(mock_entrypoint)
- self.assertEqual("test-entrypoint-dist==1.0.0", source.dist)
- self.assertEqual(str(mock_entrypoint), source.entrypoint)
- self.assertEqual("test-entrypoint-dist==1.0.0: " + str(mock_entrypoint), str(source))
- self.assertEqual("<em>test-entrypoint-dist==1.0.0:</em> " + str(mock_entrypoint), source.__html__())
-
-
class TestPoolModelView(TestBase):
def setUp(self):
super().setUp()