You are viewing a plain-text version of this content. The link to the canonical version (a hyperlink in the original page) was not preserved in this plain-text rendering.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/04/15 11:28:48 UTC

[airflow] 07/08: Fix missing on_load trigger for folder-based plugins (#15208)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7b9f0915144af397bd4b2e465af0875a3a6a6bc3
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Tue Apr 6 15:48:12 2021 -0600

    Fix missing on_load trigger for folder-based plugins (#15208)
    
    (cherry picked from commit 97b7780df48b412e104ff4adeecbe715264f00eb)
---
 airflow/plugins_manager.py            | 23 +++++++++-------
 tests/plugins/test_plugin.py          |  7 +++++
 tests/plugins/test_plugins_manager.py | 49 +++++++++++++++++++++++++++++++++++
 3 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index b68dbb9..cf957ff 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -173,13 +173,23 @@ def is_valid_plugin(plugin_obj):
     return False
 
 
+def register_plugin(plugin_instance):
+    """
+    Start plugin load and register it after success initialization
+
+    :param plugin_instance: subclass of AirflowPlugin
+    """
+    global plugins  # pylint: disable=global-statement
+    plugin_instance.on_load()
+    plugins.append(plugin_instance)
+
+
 def load_entrypoint_plugins():
     """
     Load and register plugins AirflowPlugin subclasses from the entrypoints.
     The entry_point group should be 'airflow.plugins'.
     """
     global import_errors  # pylint: disable=global-statement
-    global plugins  # pylint: disable=global-statement
 
     log.debug("Loading plugins from entrypoints")
 
@@ -191,10 +201,8 @@ def load_entrypoint_plugins():
                 continue
 
             plugin_instance = plugin_class()
-            if callable(getattr(plugin_instance, 'on_load', None)):
-                plugin_instance.on_load()
-                plugin_instance.source = EntryPointSource(entry_point, dist)
-                plugins.append(plugin_instance)
+            plugin_instance.source = EntryPointSource(entry_point, dist)
+            register_plugin(plugin_instance)
         except Exception as e:  # pylint: disable=broad-except
             log.exception("Failed to import plugin %s", entry_point.name)
             import_errors[entry_point.module] = str(e)
@@ -203,11 +211,9 @@ def load_entrypoint_plugins():
 def load_plugins_from_plugin_directory():
     """Load and register Airflow Plugins from plugins directory"""
     global import_errors  # pylint: disable=global-statement
-    global plugins  # pylint: disable=global-statement
     log.debug("Loading plugins from directory: %s", settings.PLUGINS_FOLDER)
 
     for file_path in find_path_from_directory(settings.PLUGINS_FOLDER, ".airflowignore"):
-
         if not os.path.isfile(file_path):
             continue
         mod_name, file_ext = os.path.splitext(os.path.split(file_path)[-1])
@@ -225,8 +231,7 @@ def load_plugins_from_plugin_directory():
             for mod_attr_value in (m for m in mod.__dict__.values() if is_valid_plugin(m)):
                 plugin_instance = mod_attr_value()
                 plugin_instance.source = PluginsDirectorySource(file_path)
-                plugins.append(plugin_instance)
-
+                register_plugin(plugin_instance)
         except Exception as e:  # pylint: disable=broad-except
             log.exception(e)
             log.error('Failed to import plugin %s', file_path)
diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py
index d52d8e5..ca02a39 100644
--- a/tests/plugins/test_plugin.py
+++ b/tests/plugins/test_plugin.py
@@ -127,3 +127,10 @@ class MockPluginB(AirflowPlugin):
 
 class MockPluginC(AirflowPlugin):
     name = 'plugin-c'
+
+
+class AirflowTestOnLoadPlugin(AirflowPlugin):
+    name = 'preload'
+
+    def on_load(self, *args, **kwargs):
+        self.name = 'postload'
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index f730f17..7c4d86a 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -17,18 +17,33 @@
 # under the License.
 import importlib
 import logging
+import os
 import sys
+import tempfile
 import unittest
 from unittest import mock
 
+import pytest
+
 from airflow.hooks.base import BaseHook
 from airflow.plugins_manager import AirflowPlugin
 from airflow.www import app as application
+from tests.test_utils.config import conf_vars
 from tests.test_utils.mock_plugins import mock_plugin_manager
 
 py39 = sys.version_info >= (3, 9)
 importlib_metadata = 'importlib.metadata' if py39 else 'importlib_metadata'
 
+ON_LOAD_EXCEPTION_PLUGIN = """
+from airflow.plugins_manager import AirflowPlugin
+
+class AirflowTestOnLoadExceptionPlugin(AirflowPlugin):
+    name = 'preload'
+
+    def on_load(self, *args, **kwargs):
+        raise Exception("oops")
+"""
+
 
 class TestPluginsRBAC(unittest.TestCase):
     def setUp(self):
@@ -145,6 +160,40 @@ class TestPluginsManager:
         assert caplog.records[-1].levelname == 'DEBUG'
         assert caplog.records[-1].msg == 'Loading %d plugin(s) took %.2f seconds'
 
+    def test_loads_filesystem_plugins(self, caplog):
+        from airflow import plugins_manager
+
+        with mock.patch('airflow.plugins_manager.plugins', []):
+            plugins_manager.load_plugins_from_plugin_directory()
+
+            assert 5 == len(plugins_manager.plugins)
+            for plugin in plugins_manager.plugins:
+                if 'AirflowTestOnLoadPlugin' not in str(plugin):
+                    continue
+                assert 'postload' == plugin.name
+                break
+            else:
+                pytest.fail("Wasn't able to find a registered `AirflowTestOnLoadPlugin`")
+
+            assert caplog.record_tuples == []
+
+    def test_loads_filesystem_plugins_exception(self, caplog):
+        from airflow import plugins_manager
+
+        with mock.patch('airflow.plugins_manager.plugins', []):
+            with tempfile.TemporaryDirectory() as tmpdir:
+                with open(os.path.join(tmpdir, 'testplugin.py'), "w") as f:
+                    f.write(ON_LOAD_EXCEPTION_PLUGIN)
+
+                with conf_vars({('core', 'plugins_folder'): tmpdir}):
+                    plugins_manager.load_plugins_from_plugin_directory()
+
+            assert plugins_manager.plugins == []
+
+            received_logs = caplog.text
+            assert 'Failed to import plugin' in received_logs
+            assert 'testplugin.py' in received_logs
+
     def test_should_warning_about_incompatible_plugins(self, caplog):
         class AirflowAdminViewsPlugin(AirflowPlugin):
             name = "test_admin_views_plugin"