You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/10/01 16:38:27 UTC
[airflow] branch master updated: Replace get accessible dag ids (#11027)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 427a4a8 Replace get accessible dag ids (#11027)
427a4a8 is described below
commit 427a4a8f01c414ab571578bb6b8fbe5a8c6b32ef
Author: James Timmins <ja...@astronomer.io>
AuthorDate: Thu Oct 1 09:37:00 2020 -0700
Replace get accessible dag ids (#11027)
---
airflow/www/security.py | 64 +++++++++++++++++------------------
airflow/www/views.py | 21 ++++++------
tests/www/test_security.py | 84 ++++++++++++++++++++++++++++------------------
tests/www/test_views.py | 2 +-
4 files changed, 93 insertions(+), 78 deletions(-)
diff --git a/airflow/www/security.py b/airflow/www/security.py
index 355ccf0..20686b7 100644
--- a/airflow/www/security.py
+++ b/airflow/www/security.py
@@ -22,7 +22,9 @@ from typing import Set
from flask import current_app, g
from flask_appbuilder.security.sqla import models as sqla_models
from flask_appbuilder.security.sqla.manager import SecurityManager
+from flask_appbuilder.security.sqla.models import PermissionView, Role, User
from sqlalchemy import and_, or_
+from sqlalchemy.orm import joinedload
from airflow import models
from airflow.exceptions import AirflowException
@@ -41,7 +43,9 @@ EXISTING_ROLES = {
CAN_CREATE = 'can_create'
CAN_READ = 'can_read'
+CAN_DAG_READ = 'can_dag_read'
CAN_EDIT = 'can_edit'
+CAN_DAG_EDIT = 'can_dag_edit'
CAN_DELETE = 'can_delete'
@@ -276,60 +280,54 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):
def get_readable_dags(self, user):
"""Gets the DAGs readable by authenticated user."""
- return self.get_accessible_dags(CAN_READ, user)
+ return self.get_accessible_dags([CAN_READ, CAN_DAG_READ], user)
def get_editable_dags(self, user):
"""Gets the DAGs editable by authenticated user."""
- return self.get_accessible_dags(CAN_EDIT, user)
+ return self.get_accessible_dags([CAN_EDIT, CAN_DAG_EDIT], user)
- def get_readable_dag_ids(self, user):
+ def get_readable_dag_ids(self, user) -> Set[str]:
"""Gets the DAG IDs readable by authenticated user."""
- return [dag.dag_id for dag in self.get_readable_dags(user)]
+ return set(dag.dag_id for dag in self.get_readable_dags(user))
- def get_editable_dag_ids(self, user):
+ def get_editable_dag_ids(self, user) -> Set[str]:
"""Gets the DAG IDs editable by authenticated user."""
- return [dag.dag_id for dag in self.get_editable_dags(user)]
+ return set(dag.dag_id for dag in self.get_editable_dags(user))
+
+ def get_accessible_dag_ids(self, user) -> Set[str]:
+ """Gets the DAG IDs editable or readable by authenticated user."""
+ accessible_dags = self.get_accessible_dags([CAN_EDIT, CAN_DAG_EDIT, CAN_READ, CAN_DAG_READ], user)
+ return set(dag.dag_id for dag in accessible_dags)
@provide_session
- def get_accessible_dags(self, user_action, user, session=None):
+ def get_accessible_dags(self, user_actions, user, session=None):
"""Generic function to get readable or writable DAGs for authenticated user."""
if user.is_anonymous:
return set()
+ user_query = (
+ session.query(User)
+ .options(
+ joinedload(User.roles)
+ .subqueryload(Role.permissions)
+ .options(joinedload(PermissionView.permission), joinedload(PermissionView.view_menu))
+ )
+ .filter(User.id == user.id)
+ .first()
+ )
resources = set()
- for role in user.roles:
+ for role in user_query.roles:
for permission in role.permissions:
resource = permission.view_menu.name
action = permission.permission.name
- if action == user_action:
+ if action in user_actions:
resources.add(resource)
- if 'Dag' in resources:
+
+ if bool({'Dag', 'all_dags'}.intersection(resources)):
return session.query(DagModel)
return session.query(DagModel).filter(DagModel.dag_id.in_(resources))
- def get_accessible_dag_ids(self, username=None) -> Set[str]:
- """
- Return a set of dags that user has access to(either read or write).
-
- :param username: Name of the user.
- :return: A set of dag ids that the user could access.
- """
- if not username:
- username = g.user
-
- if username.is_anonymous or 'Public' in username.roles:
- # return an empty set if the role is public
- return set()
-
- roles = {role.name for role in username.roles}
- if {'Admin', 'Viewer', 'User', 'Op'} & roles:
- return self.DAG_VMS
-
- user_perms_views = self.get_all_permissions_views()
- # return a set of all dags that the user could access
- return {view for perm, view in user_perms_views if perm in self.DAG_PERMS}
-
def has_access(self, permission, view_name, user=None) -> bool:
"""
Verify whether a given user could perform certain permission
@@ -414,7 +412,7 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):
def _merge_perm(self, permission_name, view_menu_name):
"""
- Add the new permission , view_menu to ab_permission_view_role if not exists.
+ Add the new (permission, view_menu) to assoc_permissionview_role if it doesn't exist.
It will add the related entry to ab_permission
and ab_view_menu two meta tables as well.
diff --git a/airflow/www/views.py b/airflow/www/views.py
index b6b1978..95a949b 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -33,8 +33,8 @@ import lazy_object_proxy
import nvd3
import sqlalchemy as sqla
from flask import (
- Markup, Response, current_app, escape, flash, jsonify, make_response, redirect, render_template, request,
- session as flask_session, url_for,
+ Markup, Response, current_app, escape, flash, g, jsonify, make_response, redirect, render_template,
+ request, session as flask_session, url_for,
)
from flask_appbuilder import BaseView, ModelView, expose, has_access, permission_name
from flask_appbuilder.actions import action
@@ -442,7 +442,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m
end = start + dags_per_page
# Get all the dag id the user could access
- filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+ filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
with create_session() as session:
# read orm_dags from the db
@@ -543,7 +543,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m
"""Dag statistics."""
dr = models.DagRun
- allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+ allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
@@ -588,7 +588,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m
@provide_session
def task_stats(self, session=None):
"""Task Statistics"""
- allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+ allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
if not allowed_dag_ids:
return wwwutils.json_response({})
@@ -702,7 +702,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m
@provide_session
def last_dagruns(self, session=None):
"""Last DAG runs"""
- allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+ allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
@@ -1385,7 +1385,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m
@provide_session
def blocked(self, session=None):
"""Mark Dag Blocked."""
- allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+ allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
@@ -2287,7 +2287,6 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m
return response
task = dag.get_task(task_id)
-
try:
url = task.get_extra_links(dttm, link_name)
except ValueError as err:
@@ -2416,7 +2415,7 @@ class DagFilter(BaseFilter):
def apply(self, query, func): # noqa pylint: disable=redefined-outer-name,unused-argument
if current_app.appbuilder.sm.has_all_dags_access():
return query
- filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+ filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
return query.filter(self.model.dag_id.in_(filter_dag_ids))
@@ -3136,9 +3135,9 @@ class DagModelView(AirflowModelView):
dag_ids_query = dag_ids_query.filter(DagModel.is_paused)
owners_query = owners_query.filter(DagModel.is_paused)
- filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
+ filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
# pylint: disable=no-member
- if 'all_dags' not in filter_dag_ids:
+ if not bool({'all_dags', 'Dag'}.intersection(filter_dag_ids)):
dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids))
owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids))
# pylint: enable=no-member
diff --git a/tests/www/test_security.py b/tests/www/test_security.py
index 2399dca..fc7f57a 100644
--- a/tests/www/test_security.py
+++ b/tests/www/test_security.py
@@ -20,15 +20,17 @@ import logging
import unittest
from unittest import mock
-from flask import Flask
-from flask_appbuilder import SQLA, AppBuilder, Model, expose, has_access
+from flask_appbuilder import SQLA, Model, expose, has_access
from flask_appbuilder.security.sqla import models as sqla_models
from flask_appbuilder.views import BaseView, ModelView
from sqlalchemy import Column, Date, Float, Integer, String
+from airflow import settings
from airflow.exceptions import AirflowException
-from airflow.www.security import AirflowSecurityManager
+from airflow.models import DagModel
+from airflow.www import app as application
from airflow.www.utils import CustomSQLAInterface
+from tests.test_utils.db import clear_db_runs
from tests.test_utils.mock_security_manager import MockSecurityManager
READ_WRITE = {'can_dag_read', 'can_dag_edit'}
@@ -66,22 +68,24 @@ class SomeBaseView(BaseView):
class TestSecurity(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ settings.configure_orm()
+ cls.session = settings.Session
+ cls.app = application.create_app(testing=True)
+ cls.appbuilder = cls.app.appbuilder # pylint: disable=no-member
+ cls.app.config['WTF_CSRF_ENABLED'] = False
+ cls.security_manager = cls.appbuilder.sm
+ cls.role_admin = cls.security_manager.find_role('Admin')
+ cls.user = cls.appbuilder.sm.add_user(
+ 'admin', 'admin', 'user', 'admin@fab.org', cls.role_admin, 'general'
+ )
+
def setUp(self):
- self.app = Flask(__name__)
- self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
- self.app.config['SECRET_KEY'] = 'secret_key'
- self.app.config['CSRF_ENABLED'] = False
- self.app.config['WTF_CSRF_ENABLED'] = False
self.db = SQLA(self.app)
- self.appbuilder = AppBuilder(self.app,
- self.db.session,
- security_manager_class=AirflowSecurityManager)
- self.security_manager = self.appbuilder.sm
self.appbuilder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews")
self.appbuilder.add_view(SomeModelView, "SomeModelView", category="ModelViews")
- role_admin = self.security_manager.find_role('Admin')
- self.user = self.appbuilder.sm.add_user('admin', 'admin', 'user', 'admin@fab.org',
- role_admin, 'general')
+
log.debug("Complete setup!")
def expect_user_is_in_role(self, user, rolename):
@@ -112,13 +116,14 @@ class TestSecurity(unittest.TestCase):
self.user)
def tearDown(self):
+ clear_db_runs()
self.appbuilder = None
self.app = None
self.db = None
log.debug("Complete teardown!")
def test_init_role_baseview(self):
- role_name = 'MyRole1'
+ role_name = 'MyRole3'
role_perms = ['can_some_action']
role_vms = ['SomeBaseView']
self.security_manager.init_role(role_name, role_vms, role_perms)
@@ -159,7 +164,7 @@ class TestSecurity(unittest.TestCase):
@mock.patch('airflow.www.security.AirflowSecurityManager.get_user_roles')
def test_get_all_permissions_views(self, mock_get_user_roles):
- role_name = 'MyRole1'
+ role_name = 'MyRole5'
role_perms = ['can_some_action']
role_vms = ['SomeBaseView']
self.security_manager.init_role(role_name, role_vms, role_perms)
@@ -174,23 +179,27 @@ class TestSecurity(unittest.TestCase):
self.assertEqual(len(self.security_manager
.get_all_permissions_views()), 0)
- @mock.patch('airflow.www.security.AirflowSecurityManager'
- '.get_all_permissions_views')
- @mock.patch('airflow.www.security.AirflowSecurityManager'
- '.get_user_roles')
- def test_get_accessible_dag_ids(self, mock_get_user_roles,
- mock_get_all_permissions_views):
- user = mock.MagicMock()
+ def test_get_accessible_dag_ids(self):
role_name = 'MyRole1'
- role_perms = ['can_dag_read']
- role_vms = ['dag_id']
- self.security_manager.init_role(role_name, role_vms, role_perms)
+ permission_action = ['can_dag_read']
+ dag_id = 'dag_id'
+ username = "Mr. User"
+ self.security_manager.init_role(role_name, [], [])
+ self.security_manager.sync_perm_for_dag( # type: ignore # pylint: disable=no-member
+ dag_id, access_control={role_name: permission_action}
+ )
role = self.security_manager.find_role(role_name)
- user.roles = [role]
- user.is_anonymous = False
- mock_get_all_permissions_views.return_value = {('can_dag_read', 'dag_id')}
-
- mock_get_user_roles.return_value = [role]
+ user = self.security_manager.add_user(
+ username=username,
+ first_name=username,
+ last_name=username,
+ email=f"{username}@fab.org",
+ role=role,
+ password=username,
+ )
+ dag_model = DagModel(dag_id="dag_id", fileloc="/tmp/dag_.py", schedule_interval="2 2 * * *")
+ self.session.add(dag_model)
+ self.session.commit()
self.assertEqual(self.security_manager
.get_accessible_dag_ids(user), {'dag_id'})
@@ -235,8 +244,17 @@ class TestSecurity(unittest.TestCase):
'can_varimport', # a real permission, but not a member of DAG_PERMS
'can_eat_pudding', # clearly not a real permission
]
+ username = "Mrs. User"
+ user = self.security_manager.add_user(
+ username=username,
+ first_name=username,
+ last_name=username,
+ email=f"{username}@fab.org",
+ role=self.role_admin,
+ password=username,
+ )
for permission in invalid_permissions:
- self.expect_user_is_in_role(self.user, rolename='team-a')
+ self.expect_user_is_in_role(user, rolename='team-a')
with self.assertRaises(AirflowException) as context:
self.security_manager.sync_perm_for_dag(
'access_control_test',
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index a1b412e..761208e 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -441,7 +441,7 @@ class TestAirflowBaseViews(TestBase):
state=State.RUNNING)
def test_index(self):
- with assert_queries_count(40):
+ with assert_queries_count(43):
resp = self.client.get('/', follow_redirects=True)
self.check_content_in_response('DAGs', resp)