You are viewing a plain-text version of this content. The canonical (linked) version is available from the original mailing-list archive page.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/10/01 16:38:27 UTC

[airflow] branch master updated: Replace get accessible dag ids (#11027)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 427a4a8  Replace get accessible dag ids (#11027)
427a4a8 is described below

commit 427a4a8f01c414ab571578bb6b8fbe5a8c6b32ef
Author: James Timmins <ja...@astronomer.io>
AuthorDate: Thu Oct 1 09:37:00 2020 -0700

    Replace get accessible dag ids (#11027)
---
 airflow/www/security.py    | 64 +++++++++++++++++------------------
 airflow/www/views.py       | 21 ++++++------
 tests/www/test_security.py | 84 ++++++++++++++++++++++++++++------------------
 tests/www/test_views.py    |  2 +-
 4 files changed, 93 insertions(+), 78 deletions(-)

diff --git a/airflow/www/security.py b/airflow/www/security.py
index 355ccf0..20686b7 100644
--- a/airflow/www/security.py
+++ b/airflow/www/security.py
@@ -22,7 +22,9 @@ from typing import Set
 from flask import current_app, g
 from flask_appbuilder.security.sqla import models as sqla_models
 from flask_appbuilder.security.sqla.manager import SecurityManager
+from flask_appbuilder.security.sqla.models import PermissionView, Role, User
 from sqlalchemy import and_, or_
+from sqlalchemy.orm import joinedload
 
 from airflow import models
 from airflow.exceptions import AirflowException
@@ -41,7 +43,9 @@ EXISTING_ROLES = {
 
 CAN_CREATE = 'can_create'
 CAN_READ = 'can_read'
+CAN_DAG_READ = 'can_dag_read'
 CAN_EDIT = 'can_edit'
+CAN_DAG_EDIT = 'can_dag_edit'
 CAN_DELETE = 'can_delete'
 
 
@@ -276,60 +280,54 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):
 
     def get_readable_dags(self, user):
         """Gets the DAGs readable by authenticated user."""
-        return self.get_accessible_dags(CAN_READ, user)
+        return self.get_accessible_dags([CAN_READ, CAN_DAG_READ], user)
 
     def get_editable_dags(self, user):
         """Gets the DAGs editable by authenticated user."""
-        return self.get_accessible_dags(CAN_EDIT, user)
+        return self.get_accessible_dags([CAN_EDIT, CAN_DAG_EDIT], user)
 
-    def get_readable_dag_ids(self, user):
+    def get_readable_dag_ids(self, user) -> Set[str]:
         """Gets the DAG IDs readable by authenticated user."""
-        return [dag.dag_id for dag in self.get_readable_dags(user)]
+        return set(dag.dag_id for dag in self.get_readable_dags(user))
 
-    def get_editable_dag_ids(self, user):
+    def get_editable_dag_ids(self, user) -> Set[str]:
         """Gets the DAG IDs editable by authenticated user."""
-        return [dag.dag_id for dag in self.get_editable_dags(user)]
+        return set(dag.dag_id for dag in self.get_editable_dags(user))
+
+    def get_accessible_dag_ids(self, user) -> Set[str]:
+        """Gets the DAG IDs editable or readable by authenticated user."""
+        accessible_dags = self.get_accessible_dags([CAN_EDIT, CAN_DAG_EDIT, CAN_READ, CAN_DAG_READ], user)
+        return set(dag.dag_id for dag in accessible_dags)
 
     @provide_session
-    def get_accessible_dags(self, user_action, user, session=None):
+    def get_accessible_dags(self, user_actions, user, session=None):
         """Generic function to get readable or writable DAGs for authenticated user."""
         if user.is_anonymous:
             return set()
 
+        user_query = (
+            session.query(User)
+            .options(
+                joinedload(User.roles)
+                .subqueryload(Role.permissions)
+                .options(joinedload(PermissionView.permission), joinedload(PermissionView.view_menu))
+            )
+            .filter(User.id == user.id)
+            .first()
+        )
         resources = set()
-        for role in user.roles:
+        for role in user_query.roles:
             for permission in role.permissions:
                 resource = permission.view_menu.name
                 action = permission.permission.name
-                if action == user_action:
+                if action in user_actions:
                     resources.add(resource)
-        if 'Dag' in resources:
+
+        if bool({'Dag', 'all_dags'}.intersection(resources)):
             return session.query(DagModel)
 
         return session.query(DagModel).filter(DagModel.dag_id.in_(resources))
 
-    def get_accessible_dag_ids(self, username=None) -> Set[str]:
-        """
-        Return a set of dags that user has access to(either read or write).
-
-        :param username: Name of the user.
-        :return: A set of dag ids that the user could access.
-        """
-        if not username:
-            username = g.user
-
-        if username.is_anonymous or 'Public' in username.roles:
-            # return an empty set if the role is public
-            return set()
-
-        roles = {role.name for role in username.roles}
-        if {'Admin', 'Viewer', 'User', 'Op'} & roles:
-            return self.DAG_VMS
-
-        user_perms_views = self.get_all_permissions_views()
-        # return a set of all dags that the user could access
-        return {view for perm, view in user_perms_views if perm in self.DAG_PERMS}
-
     def has_access(self, permission, view_name, user=None) -> bool:
         """
         Verify whether a given user could perform certain permission
@@ -414,7 +412,7 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):
 
     def _merge_perm(self, permission_name, view_menu_name):
         """
-        Add the new permission , view_menu to ab_permission_view_role if not exists.
+        Add the new (permission, view_menu) to assoc_permissionview_role if it doesn't exist.
         It will add the related entry to ab_permission
         and ab_view_menu two meta tables as well.
 
diff --git a/airflow/www/views.py b/airflow/www/views.py
index b6b1978..95a949b 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -33,8 +33,8 @@ import lazy_object_proxy
 import nvd3
 import sqlalchemy as sqla
 from flask import (
-    Markup, Response, current_app, escape, flash, jsonify, make_response, redirect, render_template, request,
-    session as flask_session, url_for,
+    Markup, Response, current_app, escape, flash, g, jsonify, make_response, redirect, render_template,
+    request, session as flask_session, url_for,
 )
 from flask_appbuilder import BaseView, ModelView, expose, has_access, permission_name
 from flask_appbuilder.actions import action
@@ -442,7 +442,7 @@ class Airflow(AirflowBaseView):  # noqa: D101  pylint: disable=too-many-public-m
         end = start + dags_per_page
 
         # Get all the dag id the user could access
-        filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+        filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
 
         with create_session() as session:
             # read orm_dags from the db
@@ -543,7 +543,7 @@ class Airflow(AirflowBaseView):  # noqa: D101  pylint: disable=too-many-public-m
         """Dag statistics."""
         dr = models.DagRun
 
-        allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+        allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
         if 'all_dags' in allowed_dag_ids:
             allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
 
@@ -588,7 +588,7 @@ class Airflow(AirflowBaseView):  # noqa: D101  pylint: disable=too-many-public-m
     @provide_session
     def task_stats(self, session=None):
         """Task Statistics"""
-        allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+        allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
 
         if not allowed_dag_ids:
             return wwwutils.json_response({})
@@ -702,7 +702,7 @@ class Airflow(AirflowBaseView):  # noqa: D101  pylint: disable=too-many-public-m
     @provide_session
     def last_dagruns(self, session=None):
         """Last DAG runs"""
-        allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+        allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
 
         if 'all_dags' in allowed_dag_ids:
             allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
@@ -1385,7 +1385,7 @@ class Airflow(AirflowBaseView):  # noqa: D101  pylint: disable=too-many-public-m
     @provide_session
     def blocked(self, session=None):
         """Mark Dag Blocked."""
-        allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+        allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
 
         if 'all_dags' in allowed_dag_ids:
             allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
@@ -2287,7 +2287,6 @@ class Airflow(AirflowBaseView):  # noqa: D101  pylint: disable=too-many-public-m
             return response
 
         task = dag.get_task(task_id)
-
         try:
             url = task.get_extra_links(dttm, link_name)
         except ValueError as err:
@@ -2416,7 +2415,7 @@ class DagFilter(BaseFilter):
     def apply(self, query, func): # noqa pylint: disable=redefined-outer-name,unused-argument
         if current_app.appbuilder.sm.has_all_dags_access():
             return query
-        filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+        filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
         return query.filter(self.model.dag_id.in_(filter_dag_ids))
 
 
@@ -3136,9 +3135,9 @@ class DagModelView(AirflowModelView):
             dag_ids_query = dag_ids_query.filter(DagModel.is_paused)
             owners_query = owners_query.filter(DagModel.is_paused)
 
-        filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+        filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
         # pylint: disable=no-member
-        if 'all_dags' not in filter_dag_ids:
+        if not bool({'all_dags', 'Dag'}.intersection(filter_dag_ids)):
             dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids))
             owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids))
         # pylint: enable=no-member
diff --git a/tests/www/test_security.py b/tests/www/test_security.py
index 2399dca..fc7f57a 100644
--- a/tests/www/test_security.py
+++ b/tests/www/test_security.py
@@ -20,15 +20,17 @@ import logging
 import unittest
 from unittest import mock
 
-from flask import Flask
-from flask_appbuilder import SQLA, AppBuilder, Model, expose, has_access
+from flask_appbuilder import SQLA, Model, expose, has_access
 from flask_appbuilder.security.sqla import models as sqla_models
 from flask_appbuilder.views import BaseView, ModelView
 from sqlalchemy import Column, Date, Float, Integer, String
 
+from airflow import settings
 from airflow.exceptions import AirflowException
-from airflow.www.security import AirflowSecurityManager
+from airflow.models import DagModel
+from airflow.www import app as application
 from airflow.www.utils import CustomSQLAInterface
+from tests.test_utils.db import clear_db_runs
 from tests.test_utils.mock_security_manager import MockSecurityManager
 
 READ_WRITE = {'can_dag_read', 'can_dag_edit'}
@@ -66,22 +68,24 @@ class SomeBaseView(BaseView):
 
 
 class TestSecurity(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        settings.configure_orm()
+        cls.session = settings.Session
+        cls.app = application.create_app(testing=True)
+        cls.appbuilder = cls.app.appbuilder  # pylint: disable=no-member
+        cls.app.config['WTF_CSRF_ENABLED'] = False
+        cls.security_manager = cls.appbuilder.sm
+        cls.role_admin = cls.security_manager.find_role('Admin')
+        cls.user = cls.appbuilder.sm.add_user(
+            'admin', 'admin', 'user', 'admin@fab.org', cls.role_admin, 'general'
+        )
+
     def setUp(self):
-        self.app = Flask(__name__)
-        self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
-        self.app.config['SECRET_KEY'] = 'secret_key'
-        self.app.config['CSRF_ENABLED'] = False
-        self.app.config['WTF_CSRF_ENABLED'] = False
         self.db = SQLA(self.app)
-        self.appbuilder = AppBuilder(self.app,
-                                     self.db.session,
-                                     security_manager_class=AirflowSecurityManager)
-        self.security_manager = self.appbuilder.sm
         self.appbuilder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews")
         self.appbuilder.add_view(SomeModelView, "SomeModelView", category="ModelViews")
-        role_admin = self.security_manager.find_role('Admin')
-        self.user = self.appbuilder.sm.add_user('admin', 'admin', 'user', 'admin@fab.org',
-                                                role_admin, 'general')
+
         log.debug("Complete setup!")
 
     def expect_user_is_in_role(self, user, rolename):
@@ -112,13 +116,14 @@ class TestSecurity(unittest.TestCase):
             self.user)
 
     def tearDown(self):
+        clear_db_runs()
         self.appbuilder = None
         self.app = None
         self.db = None
         log.debug("Complete teardown!")
 
     def test_init_role_baseview(self):
-        role_name = 'MyRole1'
+        role_name = 'MyRole3'
         role_perms = ['can_some_action']
         role_vms = ['SomeBaseView']
         self.security_manager.init_role(role_name, role_vms, role_perms)
@@ -159,7 +164,7 @@ class TestSecurity(unittest.TestCase):
 
     @mock.patch('airflow.www.security.AirflowSecurityManager.get_user_roles')
     def test_get_all_permissions_views(self, mock_get_user_roles):
-        role_name = 'MyRole1'
+        role_name = 'MyRole5'
         role_perms = ['can_some_action']
         role_vms = ['SomeBaseView']
         self.security_manager.init_role(role_name, role_vms, role_perms)
@@ -174,23 +179,27 @@ class TestSecurity(unittest.TestCase):
         self.assertEqual(len(self.security_manager
                              .get_all_permissions_views()), 0)
 
-    @mock.patch('airflow.www.security.AirflowSecurityManager'
-                '.get_all_permissions_views')
-    @mock.patch('airflow.www.security.AirflowSecurityManager'
-                '.get_user_roles')
-    def test_get_accessible_dag_ids(self, mock_get_user_roles,
-                                    mock_get_all_permissions_views):
-        user = mock.MagicMock()
+    def test_get_accessible_dag_ids(self):
         role_name = 'MyRole1'
-        role_perms = ['can_dag_read']
-        role_vms = ['dag_id']
-        self.security_manager.init_role(role_name, role_vms, role_perms)
+        permission_action = ['can_dag_read']
+        dag_id = 'dag_id'
+        username = "Mr. User"
+        self.security_manager.init_role(role_name, [], [])
+        self.security_manager.sync_perm_for_dag(  # type: ignore  # pylint: disable=no-member
+            dag_id, access_control={role_name: permission_action}
+        )
         role = self.security_manager.find_role(role_name)
-        user.roles = [role]
-        user.is_anonymous = False
-        mock_get_all_permissions_views.return_value = {('can_dag_read', 'dag_id')}
-
-        mock_get_user_roles.return_value = [role]
+        user = self.security_manager.add_user(
+            username=username,
+            first_name=username,
+            last_name=username,
+            email=f"{username}@fab.org",
+            role=role,
+            password=username,
+        )
+        dag_model = DagModel(dag_id="dag_id", fileloc="/tmp/dag_.py", schedule_interval="2 2 * * *")
+        self.session.add(dag_model)
+        self.session.commit()
         self.assertEqual(self.security_manager
                          .get_accessible_dag_ids(user), {'dag_id'})
 
@@ -235,8 +244,17 @@ class TestSecurity(unittest.TestCase):
             'can_varimport',  # a real permission, but not a member of DAG_PERMS
             'can_eat_pudding',  # clearly not a real permission
         ]
+        username = "Mrs. User"
+        user = self.security_manager.add_user(
+            username=username,
+            first_name=username,
+            last_name=username,
+            email=f"{username}@fab.org",
+            role=self.role_admin,
+            password=username,
+        )
         for permission in invalid_permissions:
-            self.expect_user_is_in_role(self.user, rolename='team-a')
+            self.expect_user_is_in_role(user, rolename='team-a')
             with self.assertRaises(AirflowException) as context:
                 self.security_manager.sync_perm_for_dag(
                     'access_control_test',
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index a1b412e..761208e 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -441,7 +441,7 @@ class TestAirflowBaseViews(TestBase):
             state=State.RUNNING)
 
     def test_index(self):
-        with assert_queries_count(40):
+        with assert_queries_count(43):
             resp = self.client.get('/', follow_redirects=True)
         self.check_content_in_response('DAGs', resp)