You are viewing a plain-text version of this content. The canonical link to it ("here") was not preserved in this copy.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/29 15:58:18 UTC

[airflow] branch master updated: Faster default role syncing during webserver start (#15017)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 1627323  Faster default role syncing during webserver start (#15017)
1627323 is described below

commit 1627323a197bba2c4fbd71816a9a6bd3f78c1657
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Mon Mar 29 09:58:01 2021 -0600

    Faster default role syncing during webserver start (#15017)
    
    This makes a handful of larger queries instead of many small ones when
    syncing the default Airflow roles. On my machine with 5k DAGs, this
    reduced webserver startup time by about 1 second (as a bonus, it also
    makes the tests faster).
---
 airflow/www/security.py    | 67 ++++++++++++++++++++++++++++++++++++++--------
 tests/www/test_security.py | 58 ++++++++++++++++++++++++++++++++++-----
 2 files changed, 108 insertions(+), 17 deletions(-)

diff --git a/airflow/www/security.py b/airflow/www/security.py
index 0678800..5431fd6 100644
--- a/airflow/www/security.py
+++ b/airflow/www/security.py
@@ -17,7 +17,8 @@
 # under the License.
 #
 
-from typing import Optional, Sequence, Set, Tuple
+import warnings
+from typing import Dict, Optional, Sequence, Set, Tuple
 
 from flask import current_app, g
 from flask_appbuilder.security.sqla import models as sqla_models
@@ -174,16 +175,34 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):  # pylint: disable=
     def init_role(self, role_name, perms):
         """
         Initialize the role with the permissions and related view-menus.
-
         :param role_name:
         :param perms:
         :return:
         """
-        role = self.find_role(role_name)
-        if not role:
-            role = self.add_role(role_name)
+        warnings.warn(  # deprecated shim: all work is delegated to bulk_sync_roles below
+            "`init_role` has been deprecated. Please use `bulk_sync_roles` instead.",
+            DeprecationWarning,
+            stacklevel=2,  # attribute the warning to the caller, not to this method
+        )
+        self.bulk_sync_roles([{'role': role_name, 'perms': perms}])  # one-role batch
+
+    def bulk_sync_roles(self, roles):
+        """Sync ``roles`` (dicts with 'role' and 'perms' keys); only adds, never revokes permissions."""
+        existing_roles = self._get_all_roles_with_permissions()  # one query, perms eager-loaded
+        pvs = self._get_all_non_dag_permissionviews()  # keyed by (perm name, view menu name)
+
+        for config in roles:
+            role_name = config['role']
+            perms = config['perms']
+            role = existing_roles.get(role_name) or self.add_role(role_name)  # create if missing
+
+            for perm_name, view_name in perms:
+                perm_view = pvs.get((perm_name, view_name)) or self.add_permission_view_menu(
+                    perm_name, view_name
+                )
 
-        self.add_permissions(role, set(perms))
+                if perm_view not in role.permissions:  # skip grants that already exist
+                    self.add_permission_role(role, perm_view)
 
     def add_permissions(self, role, perms):
         """Adds resource permissions to a given role."""
@@ -467,6 +486,34 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):  # pylint: disable=
             .all()
         )
 
+    def _get_all_non_dag_permissionviews(self) -> Dict[Tuple[str, str], PermissionView]:
+        """
+        Return a dict mapping (permission name, view menu name) to its PermissionView.
+        Per-DAG view menus (names starting with ``RESOURCE_DAG_PREFIX``) are excluded.
+        """
+        return {
+            (perm_name, viewmodel_name): viewmodel  # viewmodel is the PermissionView row
+            for perm_name, viewmodel_name, viewmodel in (
+                self.get_session.query(self.permissionview_model)
+                .join(self.permission_model)
+                .join(self.viewmenu_model)
+                .filter(~self.viewmenu_model.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%"))
+                .with_entities(  # fetch both names alongside the row in a single SELECT
+                    self.permission_model.name, self.viewmenu_model.name, self.permissionview_model
+                )
+                .all()
+            )
+        }
+
+    def _get_all_roles_with_permissions(self) -> Dict[str, Role]:
+        """Return a dict mapping role name to role, with ``permissions`` eagerly loaded."""
+        return {
+            r.name: r
+            for r in (  # joinedload avoids a lazy permissions query per role later on
+                self.get_session.query(self.role_model).options(joinedload(self.role_model.permissions)).all()
+            )
+        }
+
     def create_dag_specific_permissions(self) -> None:
         """
         Creates 'can_read' and 'can_edit' permissions for all active and paused DAGs.
@@ -526,11 +573,9 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):  # pylint: disable=
         self.create_perm_vm_for_all_dag()
         self.create_dag_specific_permissions()
 
-        # Create default user role.
-        for config in self.ROLE_CONFIGS:
-            role = config['role']
-            perms = config['perms']
-            self.init_role(role, perms)
+        # Sync the default roles (Admin, Viewer, User, Op, public) with related permissions
+        self.bulk_sync_roles(self.ROLE_CONFIGS)
+
         self.add_homepage_access_to_custom_roles()
         # init existing roles, the rest role could be created through UI.
         self.update_admin_perm_view()
diff --git a/tests/www/test_security.py b/tests/www/test_security.py
index f2dc96a..d334517 100644
--- a/tests/www/test_security.py
+++ b/tests/www/test_security.py
@@ -103,7 +103,7 @@ class TestSecurity(unittest.TestCase):
             fab_utils.delete_role(cls.app, role_name)
 
     def expect_user_is_in_role(self, user, rolename):
-        self.security_manager.init_role(rolename, [])
+        self.security_manager.bulk_sync_roles([{'role': rolename, 'perms': []}])
         role = self.security_manager.find_role(rolename)
         if not role:
             self.security_manager.add_role(rolename)
@@ -141,14 +141,28 @@ class TestSecurity(unittest.TestCase):
         log.debug("Complete teardown!")
 
     def test_init_role_baseview(self):
+        role_name = 'MyRole7'
+        role_perms = [('can_some_other_action', 'AnotherBaseView')]
+        with pytest.warns(  # the deprecated shim must warn before delegating
+            DeprecationWarning,
+            match="`init_role` has been deprecated\\. Please use `bulk_sync_roles` instead\\.",
+        ):
+            self.security_manager.init_role(role_name, role_perms)
+
+        role = self.appbuilder.sm.find_role(role_name)  # shim must still create the role
+        assert role is not None
+        assert len(role_perms) == len(role.permissions)
+
+    def test_bulk_sync_roles_baseview(self):
         role_name = 'MyRole3'
         role_perms = [('can_some_action', 'SomeBaseView')]
-        self.security_manager.init_role(role_name, perms=role_perms)
+        self.security_manager.bulk_sync_roles([{'role': role_name, 'perms': role_perms}])
+        # the role and its single permission must now exist
         role = self.appbuilder.sm.find_role(role_name)
         assert role is not None
         assert len(role_perms) == len(role.permissions)
 
-    def test_init_role_modelview(self):
+    def test_bulk_sync_roles_modelview(self):
         role_name = 'MyRole2'
         role_perms = [
             ('can_list', 'SomeModelView'),
@@ -157,24 +171,33 @@ class TestSecurity(unittest.TestCase):
             (permissions.ACTION_CAN_EDIT, 'SomeModelView'),
             (permissions.ACTION_CAN_DELETE, 'SomeModelView'),
         ]
-        self.security_manager.init_role(role_name, role_perms)
+        mock_roles = [{'role': role_name, 'perms': role_perms}]
+        self.security_manager.bulk_sync_roles(mock_roles)
+
         role = self.appbuilder.sm.find_role(role_name)
         assert role is not None
         assert len(role_perms) == len(role.permissions)
 
+        # Re-syncing unchanged roles must skip all writes: only the two bulk reads run
+        with assert_queries_count(2):  # One for permissionview, one for roles
+            self.security_manager.bulk_sync_roles(mock_roles)
+
     def test_update_and_verify_permission_role(self):
         role_name = 'Test_Role'
-        self.security_manager.init_role(role_name, [])
+        role_perms = []
+        mock_roles = [{'role': role_name, 'perms': role_perms}]
+        self.security_manager.bulk_sync_roles(mock_roles)
         role = self.security_manager.find_role(role_name)
 
         perm = self.security_manager.find_permission_view_menu(permissions.ACTION_CAN_EDIT, 'RoleModelView')
         self.security_manager.add_permission_role(role, perm)
         role_perms_len = len(role.permissions)
 
-        self.security_manager.init_role(role_name, [])
+        self.security_manager.bulk_sync_roles(mock_roles)
         new_role_perms_len = len(role.permissions)
 
         assert role_perms_len == new_role_perms_len
+        assert new_role_perms_len == 1  # re-sync with empty perms must not revoke the manual grant
 
     def test_verify_public_role_has_no_permissions(self):
         public = self.appbuilder.sm.find_role("Public")
@@ -574,3 +597,26 @@ class TestSecurity(unittest.TestCase):
             assert len(perm) == 2
 
         assert ('can_read', 'Connections') in perms
+
+    def test_get_all_non_dag_permissionviews(self):
+        with assert_queries_count(1):  # the whole mapping must come from one query
+            pvs = self.security_manager._get_all_non_dag_permissionviews()
+
+        assert isinstance(pvs, dict)
+        for (perm_name, viewmodel_name), perm_view in pvs.items():
+            assert isinstance(perm_name, str)
+            assert isinstance(viewmodel_name, str)
+            assert isinstance(perm_view, self.security_manager.permissionview_model)
+        # NOTE(review): nothing here asserts that per-DAG view menus are excluded.
+        assert ('can_read', 'Connections') in pvs
+
+    def test_get_all_roles_with_permissions(self):
+        with assert_queries_count(1):  # roles (with joined permissions) in a single query
+            roles = self.security_manager._get_all_roles_with_permissions()
+
+        assert isinstance(roles, dict)
+        for role_name, role in roles.items():
+            assert isinstance(role_name, str)
+            assert isinstance(role, self.security_manager.role_model)
+        # NOTE(review): role.permissions is never accessed, so eager loading is not verified.
+        assert 'Admin' in roles