You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by vi...@apache.org on 2023/11/09 16:13:55 UTC
(airflow) branch main updated: Use resource ID to perform fined-grained access in views (#35380)
This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8b1cade0d1 Use resource ID to perform fined-grained access in views (#35380)
8b1cade0d1 is described below
commit 8b1cade0d1b191a34c513d314a9c13c2e10538bd
Author: Vincent <97...@users.noreply.github.com>
AuthorDate: Thu Nov 9 11:13:48 2023 -0500
Use resource ID to perform fined-grained access in views (#35380)
---
airflow/auth/managers/fab/fab_auth_manager.py | 2 +-
airflow/auth/managers/models/resource_details.py | 2 +-
airflow/www/auth.py | 131 +++++++++-
airflow/www/security_manager.py | 137 ++++++++---
airflow/www/views.py | 195 ++++++++-------
tests/auth/managers/fab/test_fab_auth_manager.py | 7 -
tests/www/test_auth.py | 295 ++++++++++++++++++-----
tests/www/views/test_views.py | 30 ---
tests/www/views/test_views_dagrun.py | 17 +-
tests/www/views/test_views_decorators.py | 11 -
tests/www/views/test_views_pool.py | 11 +
tests/www/views/test_views_tasks.py | 6 +-
tests/www/views/test_views_variable.py | 25 +-
13 files changed, 619 insertions(+), 250 deletions(-)
diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py
index beb11da06c..6305406d2a 100644
--- a/airflow/auth/managers/fab/fab_auth_manager.py
+++ b/airflow/auth/managers/fab/fab_auth_manager.py
@@ -98,6 +98,7 @@ _MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE: dict[DagAccessEntity, tuple[str, ..
DagAccessEntity.DEPENDENCIES: (RESOURCE_DAG_DEPENDENCIES,),
DagAccessEntity.IMPORT_ERRORS: (RESOURCE_IMPORT_ERROR,),
DagAccessEntity.RUN: (RESOURCE_DAG_RUN,),
+ DagAccessEntity.SLA_MISS: (RESOURCE_SLA_MISS,),
# RESOURCE_TASK_INSTANCE has been originally misused. RESOURCE_TASK_INSTANCE referred to task definition
# AND task instances without making the difference
# To be backward compatible, we translate DagAccessEntity.TASK_INSTANCE to RESOURCE_TASK_INSTANCE AND
@@ -118,7 +119,6 @@ _MAP_ACCESS_VIEW_TO_FAB_RESOURCE_TYPE = {
AccessView.PLUGINS: RESOURCE_PLUGIN,
AccessView.PROVIDERS: RESOURCE_PROVIDER,
AccessView.TRIGGERS: RESOURCE_TRIGGER,
- AccessView.SLA: RESOURCE_SLA_MISS,
AccessView.WEBSITE: RESOURCE_WEBSITE,
}
diff --git a/airflow/auth/managers/models/resource_details.py b/airflow/auth/managers/models/resource_details.py
index a64ef68978..eb2f81644e 100644
--- a/airflow/auth/managers/models/resource_details.py
+++ b/airflow/auth/managers/models/resource_details.py
@@ -72,7 +72,6 @@ class AccessView(Enum):
PLUGINS = "PLUGINS"
PROVIDERS = "PROVIDERS"
TRIGGERS = "TRIGGERS"
- SLA = "SLA"
WEBSITE = "WEBSITE"
@@ -84,6 +83,7 @@ class DagAccessEntity(Enum):
DEPENDENCIES = "DEPENDENCIES"
IMPORT_ERRORS = "IMPORT_ERRORS"
RUN = "RUN"
+ SLA_MISS = "SLA_MISS"
TASK = "TASK"
TASK_INSTANCE = "TASK_INSTANCE"
TASK_RESCHEDULE = "TASK_RESCHEDULE"
diff --git a/airflow/www/auth.py b/airflow/www/auth.py
index 8e9231fdec..e285d517ba 100644
--- a/airflow/www/auth.py
+++ b/airflow/www/auth.py
@@ -16,18 +16,27 @@
# under the License.
from __future__ import annotations
+import functools
import logging
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Callable, Sequence, TypeVar, cast
-from flask import flash, redirect, render_template, request
+from flask import flash, redirect, render_template, request, url_for
+from flask_appbuilder._compat import as_unicode
+from flask_appbuilder.const import (
+ FLAMSG_ERR_SEC_ACCESS_DENIED,
+ LOGMSG_ERR_SEC_ACCESS_DENIED,
+ PERMISSION_PREFIX,
+)
from airflow.auth.managers.models.resource_details import (
AccessView,
ConnectionDetails,
DagAccessEntity,
DagDetails,
+ PoolDetails,
+ VariableDetails,
)
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
@@ -36,7 +45,9 @@ from airflow.www.extensions.init_auth_manager import get_auth_manager
if TYPE_CHECKING:
from airflow.auth.managers.base_auth_manager import ResourceMethod
+ from airflow.models import DagRun, Pool, SlaMiss, TaskInstance, Variable
from airflow.models.connection import Connection
+ from airflow.models.xcom import BaseXCom
T = TypeVar("T", bound=Callable)
@@ -69,6 +80,47 @@ def has_access(permissions: Sequence[tuple[str, str]] | None = None) -> Callable
return _has_access_fab(permissions)
+def has_access_with_pk(f):
+ """
+ Check permissions on a view, passing the resource primary key to the security manager.
+
+ The implementation is very similar to
+ https://github.com/dpgaspar/Flask-AppBuilder/blob/c6fecdc551629e15467fde5d06b4437379d90592/flask_appbuilder/security/decorators.py#L134
+
+ The difference is that this decorator passes the resource ID (the ``pk`` view argument) to
+ ``has_access``, which allows fine-grained access control using resource IDs.
+ """
+ if hasattr(f, "_permission_name"):
+ permission_str = f._permission_name
+ else:
+ permission_str = f.__name__
+
+ def wraps(self, *args, **kwargs):
+ permission_str = f"{PERMISSION_PREFIX}{f._permission_name}"  # assigned at decoration time (below)
+ if self.method_permission_name:
+ _permission_name = self.method_permission_name.get(f.__name__)
+ if _permission_name:
+ permission_str = f"{PERMISSION_PREFIX}{_permission_name}"
+ if permission_str in self.base_permissions and self.appbuilder.sm.has_access(
+ action_name=permission_str,
+ resource_name=self.class_permission_name,
+ resource_pk=kwargs.get("pk"),  # ``<pk>`` route argument; None when the route has none
+ ):
+ return f(self, *args, **kwargs)
+ else:
+ log.warning(LOGMSG_ERR_SEC_ACCESS_DENIED.format(permission_str, self.__class__.__name__))
+ flash(as_unicode(FLAMSG_ERR_SEC_ACCESS_DENIED), "danger")
+ return redirect(
+ url_for(
+ self.appbuilder.sm.auth_view.__class__.__name__ + ".login",
+ next=request.url,
+ )
+ )
+
+ f._permission_name = permission_str
+ return functools.update_wrapper(wraps, f)
+
+
def _has_access_no_details(is_authorized_callback: Callable[[], bool]) -> Callable[[T], T]:
"""
Check current user's permissions against required permissions.
@@ -208,14 +260,87 @@ def has_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None
return has_access_decorator
+def has_access_dag_entities(method: ResourceMethod, access_entity: DagAccessEntity) -> Callable[[T], T]:
+ def has_access_decorator(func: T):
+ @wraps(func)
+ def decorated(*args, **kwargs):
+ items: set[SlaMiss | BaseXCom | DagRun | TaskInstance] = set(args[1])
+ dags_details = [DagDetails(id=item.dag_id) for item in items if item is not None]
+ is_authorized = all(
+ [
+ get_auth_manager().is_authorized_dag(
+ method=method, access_entity=access_entity, details=dag_details
+ )
+ for dag_details in dags_details
+ ]
+ )
+ return _has_access(
+ is_authorized=is_authorized,
+ func=func,
+ args=args,
+ kwargs=kwargs,
+ )
+
+ return cast(T, decorated)
+
+ return has_access_decorator
+
+
def has_access_dataset(method: ResourceMethod) -> Callable[[T], T]:
"""Check current user's permissions against required permissions for datasets."""
return _has_access_no_details(lambda: get_auth_manager().is_authorized_dataset(method=method))
+def has_access_pool(method: ResourceMethod) -> Callable[[T], T]:
+ """Check the current user's access to each pool passed as action items."""
+ def has_access_decorator(func: T):
+ @wraps(func)
+ def decorated(*args, **kwargs):
+ # Items may be None (unresolved pk); skip them as has_access_dag_entities does.
+ pools: set[Pool] = {pool for pool in args[1] if pool is not None}
+ pools_details = [PoolDetails(name=pool.pool) for pool in pools]
+ is_authorized = all(
+ get_auth_manager().is_authorized_pool(method=method, details=pool_details)
+ for pool_details in pools_details
+ )
+ return _has_access(
+ is_authorized=is_authorized,
+ func=func,
+ args=args,
+ kwargs=kwargs,
+ )
+
+ return cast(T, decorated)
+
+ return has_access_decorator
+
+
def has_access_variable(method: ResourceMethod) -> Callable[[T], T]:
- """Check current user's permissions against required permissions for variables."""
- return _has_access_no_details(lambda: get_auth_manager().is_authorized_variable(method=method))
+ """Check the current user's access to variables, per variable key when items are passed."""
+ def has_access_decorator(func: T):
+ @wraps(func)
+ def decorated(*args, **kwargs):
+ if len(args) == 1:
+ # No items provided
+ is_authorized = get_auth_manager().is_authorized_variable(method=method)
+ else:
+ # Items may be None (unresolved pk); skip them as has_access_dag_entities does.
+ variables: set[Variable] = {variable for variable in args[1] if variable is not None}
+ variables_details = [VariableDetails(key=variable.key) for variable in variables]
+ is_authorized = all(
+ get_auth_manager().is_authorized_variable(method=method, details=variable_details)
+ for variable_details in variables_details
+ )
+ return _has_access(
+ is_authorized=is_authorized,
+ func=func,
+ args=args,
+ kwargs=kwargs,
+ )
+
+ return cast(T, decorated)
+
+ return has_access_decorator
def has_access_view(access_view: AccessView = AccessView.WEBSITE) -> Callable[[T], T]:
diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py
index 8fcd4f5d78..ffa65ce71d 100644
--- a/airflow/www/security_manager.py
+++ b/airflow/www/security_manager.py
@@ -16,16 +16,27 @@
# under the License.
from __future__ import annotations
+import json
from functools import cached_property
from typing import TYPE_CHECKING, Callable
from flask import g
+from sqlalchemy import select
from airflow.auth.managers.fab.security_manager.constants import EXISTING_ROLES as FAB_EXISTING_ROLES
-from airflow.auth.managers.models.resource_details import AccessView, DagAccessEntity
+from airflow.auth.managers.models.resource_details import (
+ AccessView,
+ ConnectionDetails,
+ DagAccessEntity,
+ DagDetails,
+ PoolDetails,
+ VariableDetails,
+)
from airflow.auth.managers.utils.fab import (
get_method_from_fab_action_map,
)
+from airflow.exceptions import AirflowException
+from airflow.models import Connection, DagRun, Pool, TaskInstance, Variable
from airflow.security.permissions import (
ACTION_CAN_ACCESS_MENU,
ACTION_CAN_READ,
@@ -54,6 +65,7 @@ from airflow.security.permissions import (
RESOURCE_XCOM,
)
from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.extensions.init_auth_manager import get_auth_manager
from airflow.www.fab_security.sqla.manager import SecurityManager
from airflow.www.utils import CustomSQLAInterface
@@ -61,6 +73,8 @@ from airflow.www.utils import CustomSQLAInterface
EXISTING_ROLES = FAB_EXISTING_ROLES
if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
from airflow.auth.managers.models.base_user import BaseUser
@@ -82,7 +96,9 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin):
if view and getattr(view, "datamodel", None):
view.datamodel = CustomSQLAInterface(view.datamodel.obj)
- def has_access(self, action_name: str, resource_name: str, user=None) -> bool:
+ def has_access(
+ self, action_name: str, resource_name: str, user=None, resource_pk: str | None = None
+ ) -> bool:
"""
Verify whether a given user could perform a certain action on the given resource.
@@ -91,26 +107,18 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin):
This function is called by FAB when accessing a view. See
https://github.com/dpgaspar/Flask-AppBuilder/blob/c6fecdc551629e15467fde5d06b4437379d90592/flask_appbuilder/security/decorators.py#L134
- The resource ID (e.g. the connection ID) is not passed to this function (see above link). Therefore,
- it is not possible to perform fine-grained access authorization with the resource ID yet. In other
- words, we can only verify the user has access to all connections and not to a specific connection.
- To make it happen, we either need to:
- - Override all views in 'airflow/www/views.py' inheriting from `AirflowModelView` and use a custom
- `has_access` decorator.
- - Wait for the new Airflow UI to come.
-
:param action_name: action_name on resource (e.g can_read, can_edit).
:param resource_name: name of view-menu or resource.
:param user: user
+ :param resource_pk: the resource primary key (e.g. the connection ID)
:return: Whether user could perform certain action on the resource.
- :rtype bool
"""
if not user:
user = g.user
is_authorized_method = self._get_auth_manager_is_authorized_method(resource_name)
if is_authorized_method:
- return is_authorized_method(action_name, user)
+ return is_authorized_method(action_name, resource_pk, user)
else:
# This means the page the user is trying to access is specific to the auth manager used
# Example: the user list view in FabAuthManager
@@ -127,7 +135,10 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin):
return None, None
@cached_property
- def _auth_manager_is_authorized_map(self) -> dict[str, Callable[[str, BaseUser | None], bool]]:
+ @provide_session
+ def _auth_manager_is_authorized_map(
+ self, session: Session = NEW_SESSION
+ ) -> dict[str, Callable[[str, str | None, BaseUser | None], bool]]:
"""
Return the map associating a FAB resource name to the corresponding auth manager is_authorized_ API.
@@ -136,90 +147,146 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin):
auth_manager = get_auth_manager()
methods = get_method_from_fab_action_map()
+ def get_connection_id(resource_pk):
+ if not resource_pk:
+ return None
+ connection = session.scalar(select(Connection).where(Connection.id == resource_pk).limit(1))
+ if not connection:
+ raise AirflowException("Connection not found")
+ return connection.conn_id
+
+ def get_dag_id_from_dagrun_id(resource_pk):
+ if not resource_pk:
+ return None
+ dagrun = session.scalar(select(DagRun).where(DagRun.id == resource_pk).limit(1))
+ if not dagrun:
+ raise AirflowException("DagRun not found")
+ return dagrun.dag_id
+
+ def get_dag_id_from_task_instance(resource_pk):
+ if not resource_pk:
+ return None
+ composite_pk = json.loads(resource_pk)  # [dag_id, task_id, run_id, map_index]
+ ti = session.scalar(
+ select(TaskInstance)
+ .where(
+ TaskInstance.dag_id == composite_pk[0],
+ TaskInstance.task_id == composite_pk[1],
+ TaskInstance.run_id == composite_pk[2],
+ TaskInstance.map_index == composite_pk[3],
+ )
+ .limit(1)
+ )
+ if not ti:
+ raise AirflowException("Task instance not found")
+ return ti.dag_id
+
+ def get_pool_name(resource_pk):
+ if not resource_pk:
+ return None
+ pool = session.scalar(select(Pool).where(Pool.id == resource_pk).limit(1))
+ if not pool:
+ raise AirflowException("Pool not found")
+ return pool.pool
+
+ def get_variable_key(resource_pk):
+ if not resource_pk:
+ return None
+ variable = session.scalar(select(Variable).where(Variable.id == resource_pk).limit(1))
+ if not variable:
+ raise AirflowException("Variable not found")
+ return variable.key
+
return {
- RESOURCE_AUDIT_LOG: lambda action, user: auth_manager.is_authorized_dag(
+ RESOURCE_AUDIT_LOG: lambda action, resource_pk, user: auth_manager.is_authorized_dag(
method=methods[action],
access_entity=DagAccessEntity.AUDIT_LOG,
user=user,
),
- RESOURCE_CLUSTER_ACTIVITY: lambda action, user: auth_manager.is_authorized_view(
+ RESOURCE_CLUSTER_ACTIVITY: lambda action, resource_pk, user: auth_manager.is_authorized_view(
access_view=AccessView.CLUSTER_ACTIVITY,
user=user,
),
- RESOURCE_CONFIG: lambda action, user: auth_manager.is_authorized_configuration(
+ RESOURCE_CONFIG: lambda action, resource_pk, user: auth_manager.is_authorized_configuration(
method=methods[action],
user=user,
),
- RESOURCE_CONNECTION: lambda action, user: auth_manager.is_authorized_connection(
+ RESOURCE_CONNECTION: lambda action, resource_pk, user: auth_manager.is_authorized_connection(
method=methods[action],
+ details=ConnectionDetails(conn_id=get_connection_id(resource_pk)),
user=user,
),
- RESOURCE_DAG: lambda action, user: auth_manager.is_authorized_dag(
+ RESOURCE_DAG: lambda action, resource_pk, user: auth_manager.is_authorized_dag(
method=methods[action],
user=user,
),
- RESOURCE_DAG_CODE: lambda action, user: auth_manager.is_authorized_dag(
+ RESOURCE_DAG_CODE: lambda action, resource_pk, user: auth_manager.is_authorized_dag(
method=methods[action],
access_entity=DagAccessEntity.CODE,
user=user,
),
- RESOURCE_DAG_DEPENDENCIES: lambda action, user: auth_manager.is_authorized_dag(
+ RESOURCE_DAG_DEPENDENCIES: lambda action, resource_pk, user: auth_manager.is_authorized_dag(
method=methods[action],
access_entity=DagAccessEntity.DEPENDENCIES,
user=user,
),
- RESOURCE_DAG_RUN: lambda action, user: auth_manager.is_authorized_dag(
+ RESOURCE_DAG_RUN: lambda action, resource_pk, user: auth_manager.is_authorized_dag(
method=methods[action],
access_entity=DagAccessEntity.RUN,
+ details=DagDetails(id=get_dag_id_from_dagrun_id(resource_pk)),
user=user,
),
- RESOURCE_DATASET: lambda action, user: auth_manager.is_authorized_dataset(
+ RESOURCE_DATASET: lambda action, resource_pk, user: auth_manager.is_authorized_dataset(
method=methods[action],
user=user,
),
- RESOURCE_DOCS: lambda action, user: auth_manager.is_authorized_view(
+ RESOURCE_DOCS: lambda action, resource_pk, user: auth_manager.is_authorized_view(
access_view=AccessView.DOCS,
user=user,
),
- RESOURCE_PLUGIN: lambda action, user: auth_manager.is_authorized_view(
+ RESOURCE_PLUGIN: lambda action, resource_pk, user: auth_manager.is_authorized_view(
access_view=AccessView.PLUGINS,
user=user,
),
- RESOURCE_JOB: lambda action, user: auth_manager.is_authorized_view(
+ RESOURCE_JOB: lambda action, resource_pk, user: auth_manager.is_authorized_view(
access_view=AccessView.JOBS,
user=user,
),
- RESOURCE_POOL: lambda action, user: auth_manager.is_authorized_pool(
+ RESOURCE_POOL: lambda action, resource_pk, user: auth_manager.is_authorized_pool(
method=methods[action],
+ details=PoolDetails(name=get_pool_name(resource_pk)),
user=user,
),
- RESOURCE_PROVIDER: lambda action, user: auth_manager.is_authorized_view(
+ RESOURCE_PROVIDER: lambda action, resource_pk, user: auth_manager.is_authorized_view(
access_view=AccessView.PROVIDERS,
user=user,
),
- RESOURCE_SLA_MISS: lambda action, user: auth_manager.is_authorized_view(
- access_view=AccessView.SLA,
+ RESOURCE_SLA_MISS: lambda action, resource_pk, user: auth_manager.is_authorized_dag(
+ method=methods[action],
+ access_entity=DagAccessEntity.SLA_MISS,
user=user,
),
- RESOURCE_TASK_INSTANCE: lambda action, user: auth_manager.is_authorized_dag(
+ RESOURCE_TASK_INSTANCE: lambda action, resource_pk, user: auth_manager.is_authorized_dag(
method=methods[action],
access_entity=DagAccessEntity.TASK_INSTANCE,
+ details=DagDetails(id=get_dag_id_from_task_instance(resource_pk)),
user=user,
),
- RESOURCE_TASK_RESCHEDULE: lambda action, user: auth_manager.is_authorized_dag(
+ RESOURCE_TASK_RESCHEDULE: lambda action, resource_pk, user: auth_manager.is_authorized_dag(
method=methods[action],
access_entity=DagAccessEntity.TASK_RESCHEDULE,
user=user,
),
- RESOURCE_TRIGGER: lambda action, user: auth_manager.is_authorized_view(
+ RESOURCE_TRIGGER: lambda action, resource_pk, user: auth_manager.is_authorized_view(
access_view=AccessView.TRIGGERS,
user=user,
),
- RESOURCE_VARIABLE: lambda action, user: auth_manager.is_authorized_variable(
+ RESOURCE_VARIABLE: lambda action, resource_pk, user: auth_manager.is_authorized_variable(
method=methods[action],
+ details=VariableDetails(key=get_variable_key(resource_pk)),
user=user,
),
- RESOURCE_XCOM: lambda action, user: auth_manager.is_authorized_dag(
+ RESOURCE_XCOM: lambda action, resource_pk, user: auth_manager.is_authorized_dag(
method=methods[action],
access_entity=DagAccessEntity.XCOM,
user=user,
@@ -239,6 +306,6 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin):
def _is_authorized_category_menu(self, category: str) -> Callable:
items = {item.name for item in self.appbuilder.menu.find(category).childs}
- return lambda action, user: any(
+ return lambda action, resource_pk, user: any(
self._get_auth_manager_is_authorized_method(fab_resource_name=item) for item in items
)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 46debf8a69..a7cc8c3c1b 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -32,10 +32,10 @@ import traceback
import warnings
from bisect import insort_left
from collections import defaultdict
-from functools import cached_property, wraps
+from functools import cached_property
from json import JSONDecodeError
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Collection, Iterator, Mapping, MutableMapping, Sequence
+from typing import TYPE_CHECKING, Any, Collection, Iterator, Mapping, MutableMapping, Sequence
from urllib.parse import unquote, urljoin, urlsplit
import configupdater
@@ -61,9 +61,10 @@ from flask import (
url_for,
)
from flask_appbuilder import BaseView, ModelView, expose
+from flask_appbuilder._compat import as_unicode
from flask_appbuilder.actions import action
+from flask_appbuilder.const import FLAMSG_ERR_SEC_ACCESS_DENIED
from flask_appbuilder.models.sqla.filters import BaseFilter
-from flask_appbuilder.security.decorators import has_access
from flask_appbuilder.urltools import get_order_args, get_page_args, get_page_size_args
from flask_appbuilder.widgets import FormWidget
from flask_babel import lazy_gettext
@@ -132,6 +133,7 @@ from airflow.utils.task_group import TaskGroup, task_group_to_dict
from airflow.utils.timezone import td_format, utcnow
from airflow.version import version
from airflow.www import auth, utils as wwwutils
+from airflow.www.auth import has_access_with_pk
from airflow.www.decorators import action_logging, gzipped
from airflow.www.extensions.init_auth_manager import get_auth_manager
from airflow.www.forms import (
@@ -3955,71 +3957,92 @@ class AirflowModelView(ModelView):
return action_logging(event=f"{self.route_base.strip('/')}.{permission_str}")(attribute)
return attribute
+ @expose("/show/<pk>", methods=["GET"])
+ @has_access_with_pk
+ def show(self, pk):
+ """
+ Show view.
-class AirflowPrivilegeVerifierModelView(AirflowModelView):
- """
- Prevents ability to pass primary keys of objects relating to DAGs you shouldn't be able to edit.
-
- This only holds for the add, update and delete operations.
- You will still need to use the `action_has_dag_edit_access()` for actions.
- """
-
- @staticmethod
- def validate_dag_edit_access(item: DagRun | TaskInstance):
- """Validate whether the user has 'can_edit' access for this specific DAG."""
- if not get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=item.dag_id)):
- raise AirflowException(f"Access denied for dag_id {item.dag_id}")
-
- def pre_add(self, item: DagRun | TaskInstance):
- self.validate_dag_edit_access(item)
-
- def pre_update(self, item: DagRun | TaskInstance):
- self.validate_dag_edit_access(item)
+ Same implementation as
+ https://github.com/dpgaspar/Flask-AppBuilder/blob/1c3af9b665ed9a3daf36673fee3327d0abf43e5b/flask_appbuilder/views.py#L566
- def pre_delete(self, item: DagRun | TaskInstance):
- self.validate_dag_edit_access(item)
+ Override it to use a custom ``has_access_with_pk`` decorator to take into consideration resource for
+ fined-grained access.
+ """
+ pk = self._deserialize_pk_if_composite(pk)
+ widgets = self._show(pk)
+ return self.render_template(
+ self.show_template,
+ pk=pk,
+ title=self.show_title,
+ widgets=widgets,
+ related_views=self._related_views,
+ )
- def post_add_redirect(self): # Required to prevent redirect loop
- return redirect(self.get_default_url())
+ @expose("/edit/<pk>", methods=["GET", "POST"])
+ @has_access_with_pk
+ def edit(self, pk):
+ """
+ Edit view.
- def post_edit_redirect(self): # Required to prevent redirect loop
- return redirect(self.get_default_url())
+ Same implementation as
+ https://github.com/dpgaspar/Flask-AppBuilder/blob/1c3af9b665ed9a3daf36673fee3327d0abf43e5b/flask_appbuilder/views.py#L602
- def post_delete_redirect(self): # Required to prevent redirect loop
- return redirect(self.get_default_url())
+ Override it to use a custom ``has_access_with_pk`` decorator to take into consideration resource for
+ fined-grained access.
+ """
+ pk = self._deserialize_pk_if_composite(pk)
+ widgets = self._edit(pk)
+ if not widgets:
+ return self.post_edit_redirect()
+ else:
+ return self.render_template(
+ self.edit_template,
+ title=self.edit_title,
+ widgets=widgets,
+ related_views=self._related_views,
+ )
+ @expose("/delete/<pk>", methods=["GET", "POST"])
+ @has_access_with_pk
+ def delete(self, pk):
+ """
+ Delete view.
-def action_has_dag_edit_access(action_func: Callable) -> Callable:
- """Verify you have DAG edit access on the given tis/drs."""
+ Same implementation as
+ https://github.com/dpgaspar/Flask-AppBuilder/blob/1c3af9b665ed9a3daf36673fee3327d0abf43e5b/flask_appbuilder/views.py#L623
- @wraps(action_func)
- def check_dag_edit_acl_for_actions(
- self,
- items: list[TaskInstance] | list[DagRun] | TaskInstance | DagRun | None,
- *args,
- **kwargs,
- ) -> Callable:
- if items is None:
- dag_ids: set[str] = set()
- elif isinstance(items, list):
- dag_ids = {item.dag_id for item in items if item is not None}
- elif isinstance(items, (TaskInstance, DagRun)):
- dag_ids = {items.dag_id}
- else:
- raise ValueError(
- "Was expecting the first argument of the action to be of type "
- "list[TaskInstance] | list[DagRun] | TaskInstance | DagRun | None."
- f"Was of type: {type(items)}"
- )
+ Override it to use a custom ``has_access_with_pk`` decorator to take into consideration resource for
+ fined-grained access.
+ """
+ # Maintains compatibility but refuses to delete on GET methods if CSRF is enabled
+ if not self.is_get_mutation_allowed():
+ self.update_redirect()
+ logging.warning("CSRF is enabled and a delete using GET was invoked")
+ flash(as_unicode(FLAMSG_ERR_SEC_ACCESS_DENIED), "danger")
+ return self.post_delete_redirect()
+ pk = self._deserialize_pk_if_composite(pk)
+ self._delete(pk)
+ return self.post_delete_redirect()
+
+ @expose("/action_post", methods=["POST"])
+ def action_post(self):
+ """
+ Action method to handle multiple records selected from a list view.
- for dag_id in dag_ids:
- if not get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=dag_id)):
- flash(f"Access denied for dag_id {dag_id}", "danger")
- logging.warning("User %s tried to modify %s without having access.", g.user.username, dag_id)
- return redirect(self.get_default_url())
- return action_func(self, items, *args, **kwargs)
+ Same implementation as
+ https://github.com/dpgaspar/Flask-AppBuilder/blob/2c5763371b81cd679d88b9971ba5d1fc4d71d54b/flask_appbuilder/views.py#L677
- return check_dag_edit_acl_for_actions
+ The difference is that it no longer checks permissions with ``self.appbuilder.sm.has_access``;
+ it executes the function without verifying permissions.
+ Thus, each action needs to be annotated individually with ``@auth.has_access_*`` to check user
+ permissions.
+ """
+ name = request.form["action"]
+ pks = request.form.getlist("rowid")
+ action = self.actions.get(name)  # NOTE(review): presumably None for an unknown name -> AttributeError below
+ items = [self.datamodel.get(self._deserialize_pk_if_composite(pk)) for pk in pks]
+ return action.func(items)
class SlaMissModelView(AirflowModelView):
@@ -4065,6 +4088,7 @@ class SlaMissModelView(AirflowModelView):
}
@action("muldelete", "Delete", "Are you sure you want to delete selected records?", single=False)
+ @auth.has_access_dag_entities("DELETE", DagAccessEntity.SLA_MISS)
def action_muldelete(self, items):
"""Multiple delete action."""
self.datamodel.delete_all(items)
@@ -4077,6 +4101,7 @@ class SlaMissModelView(AirflowModelView):
"Are you sure you want to set all these notifications to sent?",
single=False,
)
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.SLA_MISS)
def action_mulnotificationsent(self, items: list[SlaMiss]):
return self._set_notification_property(items, "notification_sent", True)
@@ -4086,6 +4111,7 @@ class SlaMissModelView(AirflowModelView):
"Are you sure you want to mark these SLA alerts as notification not sent yet?",
single=False,
)
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.SLA_MISS)
def action_mulnotificationsentfalse(self, items: list[SlaMiss]):
return self._set_notification_property(items, "notification_sent", False)
@@ -4095,6 +4121,7 @@ class SlaMissModelView(AirflowModelView):
"Are you sure you want to mark these SLA alerts as emails were sent?",
single=False,
)
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.SLA_MISS)
def action_mulemailsent(self, items: list[SlaMiss]):
return self._set_notification_property(items, "email_sent", True)
@@ -4104,6 +4131,7 @@ class SlaMissModelView(AirflowModelView):
"Are you sure you want to mark these SLA alerts as emails not sent yet?",
single=False,
)
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.SLA_MISS)
def action_mulemailsentfalse(self, items: list[SlaMiss]):
return self._set_notification_property(items, "email_sent", False)
@@ -4146,7 +4174,6 @@ class XComModelView(AirflowModelView):
"action_muldelete": "delete",
}
base_permissions = [
- permissions.ACTION_CAN_CREATE,
permissions.ACTION_CAN_READ,
permissions.ACTION_CAN_DELETE,
permissions.ACTION_CAN_ACCESS_MENU,
@@ -4167,6 +4194,7 @@ class XComModelView(AirflowModelView):
}
@action("muldelete", "Delete", "Are you sure you want to delete selected records?", single=False)
+ @auth.has_access_dag_entities("DELETE", DagAccessEntity.XCOM)
def action_muldelete(self, items):
"""Multiple delete action."""
self.datamodel.delete_all(items)
@@ -4662,6 +4690,7 @@ class PoolModelView(AirflowModelView):
base_order = ("pool", "asc")
@action("muldelete", "Delete", "Are you sure you want to delete selected records?", single=False)
+ @auth.has_access_pool("DELETE")
def action_muldelete(self, items):
"""Multiple delete."""
if any(item.pool == models.Pool.DEFAULT_POOL_NAME for item in items):
@@ -4673,7 +4702,7 @@ class PoolModelView(AirflowModelView):
return redirect(self.get_redirect())
@expose("/delete/<pk>", methods=["GET", "POST"])
- @has_access
+ @has_access_with_pk
def delete(self, pk):
"""Single delete."""
if models.Pool.is_default_pool(pk):
@@ -4747,12 +4776,6 @@ class PoolModelView(AirflowModelView):
validators_columns = {"pool": [validators.DataRequired()], "slots": [validators.NumberRange(min=-1)]}
-def _can_create_variable() -> bool:
- return get_airflow_app().appbuilder.sm.has_access(
- permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE
- )
-
-
class VariableModelView(AirflowModelView):
"""View to show records from Variable table."""
@@ -4833,9 +4856,10 @@ class VariableModelView(AirflowModelView):
item, orders=orders, pages=pages, page_sizes=page_sizes, widgets=widgets
)
- extra_args = {"can_create_variable": _can_create_variable}
+ extra_args = {"can_create_variable": lambda: get_auth_manager().is_authorized_variable(method="POST")}
@action("muldelete", "Delete", "Are you sure you want to delete selected records?", single=False)
+ @auth.has_access_variable("DELETE")
def action_muldelete(self, items):
"""Multiple delete."""
self.datamodel.delete_all(items)
@@ -4843,6 +4867,7 @@ class VariableModelView(AirflowModelView):
return redirect(self.get_redirect())
@action("varexport", "Export", "", single=False)
+ @auth.has_access_variable("GET")
def action_varexport(self, items):
"""Export variables."""
var_dict = {}
@@ -4966,7 +4991,7 @@ class JobModelView(AirflowModelView):
}
-class DagRunModelView(AirflowPrivilegeVerifierModelView):
+class DagRunModelView(AirflowModelView):
"""View to show records from DagRun table."""
route_base = "/dagrun"
@@ -5078,7 +5103,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
}
@action("muldelete", "Delete", "Are you sure you want to delete selected records?", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("DELETE", DagAccessEntity.RUN)
@action_logging
def action_muldelete(self, items: list[DagRun]):
"""Multiple delete."""
@@ -5087,14 +5112,14 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
return redirect(self.get_redirect())
@action("set_queued", "Set state to 'queued'", "", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.RUN)
@action_logging
def action_set_queued(self, drs: list[DagRun]):
"""Set state to queued."""
return self._set_dag_runs_to_active_state(drs, DagRunState.QUEUED)
@action("set_running", "Set state to 'running'", "", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.RUN)
@action_logging
def action_set_running(self, drs: list[DagRun]):
"""Set state to running."""
@@ -5128,7 +5153,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
"All running task instances would also be marked as failed, are you sure?",
single=False,
)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.RUN)
@provide_session
@action_logging
def action_set_failed(self, drs: list[DagRun], session: Session = NEW_SESSION):
@@ -5156,7 +5181,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
"All task instances would also be marked as success, are you sure?",
single=False,
)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.RUN)
@provide_session
@action_logging
def action_set_success(self, drs: list[DagRun], session: Session = NEW_SESSION):
@@ -5179,7 +5204,7 @@ class DagRunModelView(AirflowPrivilegeVerifierModelView):
return redirect(self.get_default_url())
@action("clear", "Clear the state", "All task instances would be cleared, are you sure?", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.RUN)
@provide_session
@action_logging
def action_clear(self, drs: list[DagRun], session: Session = NEW_SESSION):
@@ -5367,7 +5392,7 @@ class TriggerModelView(AirflowModelView):
}
-class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
+class TaskInstanceModelView(AirflowModelView):
"""View to show records from TaskInstance table."""
route_base = "/taskinstance"
@@ -5587,7 +5612,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
),
single=False,
)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.TASK_INSTANCE)
@provide_session
@action_logging
def action_clear(self, task_instances, session: Session = NEW_SESSION):
@@ -5613,7 +5638,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
),
single=False,
)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.TASK_INSTANCE)
@provide_session
@action_logging
def action_clear_downstream(self, task_instances, session: Session = NEW_SESSION):
@@ -5634,7 +5659,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
return redirect(self.get_redirect())
@action("muldelete", "Delete", "Are you sure you want to delete selected records?", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("DELETE", DagAccessEntity.TASK_INSTANCE)
@action_logging
def action_muldelete(self, items):
self.datamodel.delete_all(items)
@@ -5659,7 +5684,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
flash("Failed to set state", "error")
@action("set_running", "Set state to 'running'", "", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.TASK_INSTANCE)
@action_logging
def action_set_running(self, tis):
"""Set state to 'running'."""
@@ -5668,7 +5693,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
return redirect(self.get_redirect())
@action("set_failed", "Set state to 'failed'", "", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.TASK_INSTANCE)
@action_logging
def action_set_failed(self, tis):
"""Set state to 'failed'."""
@@ -5677,7 +5702,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
return redirect(self.get_redirect())
@action("set_success", "Set state to 'success'", "", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.TASK_INSTANCE)
@action_logging
def action_set_success(self, tis):
"""Set state to 'success'."""
@@ -5686,7 +5711,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
return redirect(self.get_redirect())
@action("set_retry", "Set state to 'up_for_retry'", "", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.TASK_INSTANCE)
@action_logging
def action_set_retry(self, tis):
"""Set state to 'up_for_retry'."""
@@ -5695,7 +5720,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
return redirect(self.get_redirect())
@action("set_skipped", "Set state to 'skipped'", "", single=False)
- @action_has_dag_edit_access
+ @auth.has_access_dag_entities("PUT", DagAccessEntity.TASK_INSTANCE)
@action_logging
def action_set_skipped(self, tis):
"""Set state to skipped."""
@@ -5821,8 +5846,8 @@ def add_user_permissions_to_dag(sender, template, context, **extra):
if "dag" not in context:
return
dag = context["dag"]
- can_create_dag_run = get_airflow_app().appbuilder.sm.has_access(
- permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN
+ can_create_dag_run = get_auth_manager().is_authorized_dag(
+ method="POST", access_entity=DagAccessEntity.RUN
)
dag.can_edit = get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=dag.dag_id))
diff --git a/tests/auth/managers/fab/test_fab_auth_manager.py b/tests/auth/managers/fab/test_fab_auth_manager.py
index f7cb7070d7..934293606a 100644
--- a/tests/auth/managers/fab/test_fab_auth_manager.py
+++ b/tests/auth/managers/fab/test_fab_auth_manager.py
@@ -43,7 +43,6 @@ from airflow.security.permissions import (
RESOURCE_JOB,
RESOURCE_PLUGIN,
RESOURCE_PROVIDER,
- RESOURCE_SLA_MISS,
RESOURCE_TASK_INSTANCE,
RESOURCE_TRIGGER,
RESOURCE_VARIABLE,
@@ -362,12 +361,6 @@ class TestFabAuthManager:
[(ACTION_CAN_READ, RESOURCE_TRIGGER)],
True,
),
- # With permission (SLA)
- (
- AccessView.SLA,
- [(ACTION_CAN_READ, RESOURCE_SLA_MISS)],
- True,
- ),
# With permission (website)
(
AccessView.WEBSITE,
diff --git a/tests/www/test_auth.py b/tests/www/test_auth.py
index 05a92e0c7e..83fb0eebdf 100644
--- a/tests/www/test_auth.py
+++ b/tests/www/test_auth.py
@@ -17,79 +17,252 @@
from __future__ import annotations
-from unittest.mock import patch
+from unittest.mock import Mock, patch
import pytest
+import airflow.www.auth as auth
+from airflow.auth.managers.models.resource_details import DagAccessEntity
+from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.models import Connection, Pool, Variable
from airflow.security import permissions
from airflow.settings import json
+from airflow.www.auth import has_access
from tests.test_utils.api_connexion_utils import create_user_scope
-from tests.www.test_security import SomeBaseView, SomeModelView
-pytestmark = pytest.mark.db_test
+mock_call = Mock()
-@pytest.fixture(scope="module")
-def app_builder(app):
- app_builder = app.appbuilder
- app_builder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews")
- app_builder.add_view(SomeModelView, "SomeModelView", category="ModelViews")
- return app.appbuilder
+class TestHasAccessDecorator:
+ def test_has_access_decorator_raises_deprecation_warning(self):
+ with pytest.warns(RemovedInAirflow3Warning):
+
+ @has_access
+ def test_function():
+ pass
+
+
+@pytest.mark.parametrize(
+ "decorator_name, is_authorized_method_name",
+ [
+ ("has_access_cluster_activity", "is_authorized_cluster_activity"),
+ ("has_access_configuration", "is_authorized_configuration"),
+ ("has_access_dataset", "is_authorized_dataset"),
+ ("has_access_view", "is_authorized_view"),
+ ],
+)
+class TestHasAccessNoDetails:
+ def setup_method(self):
+ mock_call.reset_mock()
+
+ def method_test(self):
+ mock_call()
+ return True
+
+ @patch("airflow.www.auth.get_auth_manager")
+ def test_has_access_no_details_when_authorized(
+ self, mock_get_auth_manager, decorator_name, is_authorized_method_name
+ ):
+ auth_manager = Mock()
+ is_authorized_method = Mock()
+ is_authorized_method.return_value = True
+ setattr(auth_manager, is_authorized_method_name, is_authorized_method)
+ mock_get_auth_manager.return_value = auth_manager
+
+ result = getattr(auth, decorator_name)("GET")(self.method_test)()
+
+ mock_call.assert_called_once()
+ assert result is True
+
+ @patch("airflow.www.auth.get_auth_manager")
+ @patch("airflow.www.auth.render_template")
+ def test_has_access_no_details_when_no_permission(
+ self, mock_render_template, mock_get_auth_manager, decorator_name, is_authorized_method_name
+ ):
+ auth_manager = Mock()
+ is_authorized_method = Mock()
+ is_authorized_method.return_value = False
+ setattr(auth_manager, is_authorized_method_name, is_authorized_method)
+ auth_manager.is_logged_in.return_value = True
+ auth_manager.is_authorized_view.return_value = False
+ mock_get_auth_manager.return_value = auth_manager
+
+ getattr(auth, decorator_name)("GET")(self.method_test)()
+
+ mock_call.assert_not_called()
+ mock_render_template.assert_called_once()
+
+ @pytest.mark.db_test
+ @patch("airflow.www.auth.get_auth_manager")
+ def test_has_access_no_details_when_not_logged_in(
+ self, mock_get_auth_manager, app, decorator_name, is_authorized_method_name
+ ):
+ auth_manager = Mock()
+ is_authorized_method = Mock()
+ is_authorized_method.return_value = False
+ setattr(auth_manager, is_authorized_method_name, is_authorized_method)
+ auth_manager.is_logged_in.return_value = False
+ auth_manager.get_url_login.return_value = "login_url"
+ mock_get_auth_manager.return_value = auth_manager
+
+ with app.test_request_context():
+ result = getattr(auth, decorator_name)("GET")(self.method_test)()
+
+ mock_call.assert_not_called()
+ assert result.status_code == 302
@pytest.mark.parametrize(
- "dag_id_args, dag_id_kwargs, dag_id_form, dag_id_json, fail",
+ "decorator_name, is_authorized_method_name, items",
[
- ("a", None, None, None, False),
- (None, "b", None, None, False),
- (None, None, "c", None, False),
- (None, None, None, "d", False),
- ("a", "a", None, None, False),
- ("a", "a", "a", None, False),
- ("a", "a", "a", "a", False),
- (None, "a", "a", "a", False),
- (None, None, "a", "a", False),
- ("a", None, None, "a", False),
- ("a", None, "a", None, False),
- ("a", None, "c", None, True),
- (None, "b", "c", None, True),
- (None, None, "c", "d", True),
- ("a", "b", "c", "d", True),
+ ("has_access_connection", "is_authorized_connection", [Connection("conn_1"), Connection("conn_2")]),
+ ("has_access_pool", "is_authorized_pool", [Pool(pool="pool_1"), Pool(pool="pool_2")]),
+ ("has_access_variable", "is_authorized_variable", [Variable("var_1"), Variable("var_2")]),
],
)
-def test_dag_id_consistency(
- app,
- dag_id_args: str | None,
- dag_id_kwargs: str | None,
- dag_id_form: str | None,
- dag_id_json: str | None,
- fail: bool,
-):
- with app.test_request_context() as mock_context:
- from airflow.www.auth import has_access_dag
-
- mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args else {}
- kwargs = {"dag_id": dag_id_kwargs} if dag_id_kwargs else {}
- mock_context.request.form = {"dag_id": dag_id_form} if dag_id_form else {}
- if dag_id_json:
- mock_context.request._cached_data = json.dumps({"dag_id": dag_id_json})
- mock_context.request._parsed_content_type = ["application/json"]
-
- with create_user_scope(
- app,
- username="test-user",
- role_name="limited-role",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)],
- ) as user:
- with patch("airflow.auth.managers.fab.fab_auth_manager.FabAuthManager.get_user") as mock_get_user:
- mock_get_user.return_value = user
-
- @has_access_dag("GET")
- def test_func(**kwargs):
- return True
-
- result = test_func(**kwargs)
- if fail:
- assert result[1] == 403
- else:
- assert result is True
+class TestHasAccessWithDetails:
+ def setup_method(self):
+ mock_call.reset_mock()
+
+ def method_test(self, _view, arg):
+ mock_call()
+ return True
+
+ @patch("airflow.www.auth.get_auth_manager")
+ def test_has_access_with_details_when_authorized(
+ self, mock_get_auth_manager, decorator_name, is_authorized_method_name, items
+ ):
+ auth_manager = Mock()
+ is_authorized_method = Mock()
+ is_authorized_method.return_value = True
+ setattr(auth_manager, is_authorized_method_name, is_authorized_method)
+ mock_get_auth_manager.return_value = auth_manager
+
+ result = getattr(auth, decorator_name)("GET")(self.method_test)(None, items)
+
+ mock_call.assert_called_once()
+ assert result is True
+
+ @pytest.mark.db_test
+ @patch("airflow.www.auth.get_auth_manager")
+ def test_has_access_with_details_when_unauthorized(
+ self, mock_get_auth_manager, app, decorator_name, is_authorized_method_name, items
+ ):
+ auth_manager = Mock()
+ is_authorized_method = Mock()
+ is_authorized_method.return_value = False
+ setattr(auth_manager, is_authorized_method_name, is_authorized_method)
+ mock_get_auth_manager.return_value = auth_manager
+
+ with app.test_request_context():
+ result = getattr(auth, decorator_name)("GET")(self.method_test)(None, items)
+
+ mock_call.assert_not_called()
+ assert result.status_code == 302
+
+
+@pytest.mark.parametrize(
+ "dag_access_entity",
+ [
+ DagAccessEntity.SLA_MISS,
+ DagAccessEntity.XCOM,
+ DagAccessEntity.RUN,
+ DagAccessEntity.TASK_INSTANCE,
+ ],
+)
+class TestHasAccessDagEntities:
+ def setup_method(self):
+ mock_call.reset_mock()
+
+ def method_test(self, _view, arg):
+ mock_call()
+ return True
+
+ @patch("airflow.www.auth.get_auth_manager")
+ def test_has_access_dag_entities_when_authorized(self, mock_get_auth_manager, dag_access_entity):
+ auth_manager = Mock()
+ auth_manager.is_authorized_dag.return_value = True
+ mock_get_auth_manager.return_value = auth_manager
+ items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")]
+
+ result = auth.has_access_dag_entities("GET", dag_access_entity)(self.method_test)(None, items)
+
+ mock_call.assert_called_once()
+ assert result is True
+
+ @pytest.mark.db_test
+ @patch("airflow.www.auth.get_auth_manager")
+ def test_has_access_dag_entities_when_unauthorized(self, mock_get_auth_manager, app, dag_access_entity):
+ auth_manager = Mock()
+ auth_manager.is_authorized_dag.return_value = False
+ mock_get_auth_manager.return_value = auth_manager
+ items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")]
+
+ with app.test_request_context():
+ result = auth.has_access_dag_entities("GET", dag_access_entity)(self.method_test)(None, items)
+
+ mock_call.assert_not_called()
+ assert result.status_code == 302
+
+
+@pytest.mark.db_test
+class TestHasAccessDagDecorator:
+ @pytest.mark.parametrize(
+ "dag_id_args, dag_id_kwargs, dag_id_form, dag_id_json, fail",
+ [
+ ("a", None, None, None, False),
+ (None, "b", None, None, False),
+ (None, None, "c", None, False),
+ (None, None, None, "d", False),
+ ("a", "a", None, None, False),
+ ("a", "a", "a", None, False),
+ ("a", "a", "a", "a", False),
+ (None, "a", "a", "a", False),
+ (None, None, "a", "a", False),
+ ("a", None, None, "a", False),
+ ("a", None, "a", None, False),
+ ("a", None, "c", None, True),
+ (None, "b", "c", None, True),
+ (None, None, "c", "d", True),
+ ("a", "b", "c", "d", True),
+ ],
+ )
+ def test_dag_id_consistency(
+ self,
+ app,
+ dag_id_args: str | None,
+ dag_id_kwargs: str | None,
+ dag_id_form: str | None,
+ dag_id_json: str | None,
+ fail: bool,
+ ):
+ with app.test_request_context() as mock_context:
+ from airflow.www.auth import has_access_dag
+
+ mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args else {}
+ kwargs = {"dag_id": dag_id_kwargs} if dag_id_kwargs else {}
+ mock_context.request.form = {"dag_id": dag_id_form} if dag_id_form else {}
+ if dag_id_json:
+ mock_context.request._cached_data = json.dumps({"dag_id": dag_id_json})
+ mock_context.request._parsed_content_type = ["application/json"]
+
+ with create_user_scope(
+ app,
+ username="test-user",
+ role_name="limited-role",
+ permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)],
+ ) as user:
+ with patch(
+ "airflow.auth.managers.fab.fab_auth_manager.FabAuthManager.get_user"
+ ) as mock_get_user:
+ mock_get_user.return_value = user
+
+ @has_access_dag("GET")
+ def test_func(**kwargs):
+ return True
+
+ result = test_func(**kwargs)
+ if fail:
+ assert result[1] == 403
+ else:
+ assert result is True
diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py
index 532064bb82..c49ce88f51 100644
--- a/tests/www/views/test_views.py
+++ b/tests/www/views/test_views.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import os
import re
-from typing import Callable
from unittest import mock
import pytest
@@ -31,7 +30,6 @@ from airflow.configuration import (
)
from airflow.plugins_manager import AirflowPlugin, EntryPointSource
from airflow.utils.task_group import TaskGroup
-from airflow.www import views
from airflow.www.views import (
get_key_paths,
get_safe_url,
@@ -456,34 +454,6 @@ def test_get_value_from_path(test_content_dict, test_key_path, expected_value):
assert expected_value == get_value_from_path(test_key_path, test_content_dict)
-def assert_decorator_used(cls: type, fn_name: str, decorator: Callable):
- fn = getattr(cls, fn_name)
- code = decorator(None).__code__
- while fn is not None:
- if fn.__code__ is code:
- return
- if not hasattr(fn, "__wrapped__"):
- break
- fn = getattr(fn, "__wrapped__")
- assert False, f"{cls.__name__}.{fn_name} was not decorated with @{decorator.__name__}"
-
-
-@pytest.mark.parametrize(
- "cls",
- [
- views.TaskInstanceModelView,
- views.DagRunModelView,
- ],
-)
-def test_dag_edit_privileged_requires_view_has_action_decorators(cls: type):
- action_funcs = {func for func in dir(cls) if callable(getattr(cls, func)) and func.startswith("action_")}
-
- # We remove action_post as this is a standard SQLAlchemy function no enable other action functions.
- action_funcs = action_funcs - {"action_post"}
- for action_function in action_funcs:
- assert_decorator_used(cls, action_function, views.action_has_dag_edit_access)
-
-
def test_get_task_stats_from_query():
query_data = [
["dag1", "queued", True, 1],
diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py
index eccf73c6b5..f55a6f16ca 100644
--- a/tests/www/views/test_views_dagrun.py
+++ b/tests/www/views/test_views_dagrun.py
@@ -17,8 +17,6 @@
# under the License.
from __future__ import annotations
-import flask
-import markupsafe
import pytest
from airflow.models import DagBag, DagRun, TaskInstance
@@ -40,6 +38,7 @@ def client_dr_without_dag_edit(app):
username="all_dr_permissions_except_dag_edit",
role_name="all_dr_permissions_except_dag_edit",
permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
@@ -108,15 +107,7 @@ def test_get_dagrun_can_view_dags_without_edit_perms(session, running_dag_run, c
"""Test that a user without dag_edit but with dag_read permission can view the records"""
assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1
resp = client_dr_without_dag_edit.get("/dagrun/list/", follow_redirects=True)
-
- with client_dr_without_dag_edit.application.test_request_context():
- url = flask.url_for(
- "Airflow.graph", dag_id=running_dag_run.dag_id, execution_date=running_dag_run.execution_date
- )
- dag_url_link = markupsafe.Markup('<a href="{url}">{dag_id}</a>').format(
- url=url, dag_id=running_dag_run.dag_id
- )
- check_content_in_response(dag_url_link, resp)
+ check_content_in_response(running_dag_run.dag_id, resp)
def test_create_dagrun_permission_denied(session, client_dr_without_dag_run_create):
@@ -190,7 +181,7 @@ def test_delete_dagrun_permission_denied(session, running_dag_run, client_dr_wit
assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1
resp = client_dr_without_dag_edit.post(f"/dagrun/delete/{composite_key}", follow_redirects=True)
- check_content_in_response(f"Access denied for dag_id {running_dag_run.dag_id}", resp)
+ check_content_in_response("Access is Denied", resp)
assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1
@@ -286,7 +277,7 @@ def test_set_dag_runs_action_permission_denied(client_dr_without_dag_edit, runni
data={"action": action, "rowid": [str(running_dag_id)]},
follow_redirects=True,
)
- check_content_in_response(f"Access denied for dag_id {running_dag_run.dag_id}", resp)
+ check_content_in_response("Access is Denied", resp)
def test_dag_runs_queue_new_tasks_action(session, admin_client, completed_dag_run_with_missing_task):
diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py
index 35ba2a968e..00e657a2d7 100644
--- a/tests/www/views/test_views_decorators.py
+++ b/tests/www/views/test_views_decorators.py
@@ -25,7 +25,6 @@ from airflow.models import DagBag, Variable
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.types import DagRunType
-from airflow.www.views import action_has_dag_edit_access
from tests.test_utils.db import clear_db_runs, clear_db_variables
from tests.test_utils.www import _check_last_log, _check_last_log_masked_variable, check_content_in_response
@@ -93,11 +92,6 @@ def clean_db():
clear_db_variables()
-@action_has_dag_edit_access
-def some_view_action_which_requires_dag_edit_access(*args) -> bool:
- return True
-
-
def test_action_logging_get(session, admin_client):
url = (
f"dags/example_bash_operator/grid?"
@@ -185,8 +179,3 @@ def test_calendar(admin_client, dagruns):
datestr = bash_dagrun.execution_date.date().isoformat()
expected = rf"{{\"date\":\"{datestr}\",\"state\":\"running\",\"count\":1}}"
check_content_in_response(expected, resp)
-
-
-def test_action_has_dag_edit_access_exception():
- with pytest.raises(ValueError):
- some_view_action_which_requires_dag_edit_access(None, "some_incorrect_value")
diff --git a/tests/www/views/test_views_pool.py b/tests/www/views/test_views_pool.py
index 4aa3c121bf..783f6de125 100644
--- a/tests/www/views/test_views_pool.py
+++ b/tests/www/views/test_views_pool.py
@@ -129,3 +129,14 @@ def test_pool_muldelete_default(session, admin_client, pool_factory):
)
check_content_in_response("default_pool cannot be deleted", resp)
assert session.query(Pool).filter(Pool.id == pool.id).count() == 1
+
+
+def test_pool_muldelete_access_denied(session, viewer_client, pool_factory):
+ pool = pool_factory()
+
+ resp = viewer_client.post(
+ "/pool/action_post",
+ data={"action": "muldelete", "rowid": [pool.id]},
+ follow_redirects=True,
+ )
+ check_content_in_response("Access is Denied", resp)
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index c22ea5e2e7..55568d4d8f 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -832,7 +832,7 @@ def test_task_instance_delete_permission_denied(session, client_ti_without_dag_e
assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 1
resp = client_ti_without_dag_edit.post(f"/taskinstance/delete/{composite_key}", follow_redirects=True)
- check_content_in_response(f"Access denied for dag_id {task_instance_to_delete.dag_id}", resp)
+ check_content_in_response("Access is Denied", resp)
assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 1
@@ -862,7 +862,9 @@ def test_task_instance_clear(session, request, client_fixture, should_succeed):
data={"action": "clear", "rowid": rowid},
follow_redirects=True,
)
- assert resp.status_code == (200 if should_succeed else 404)
+ assert resp.status_code == 200
+ if not should_succeed and client_fixture != "anonymous_client":
+ check_content_in_response("Access is Denied", resp)
# Now the state should be None.
state = session.query(TaskInstance.state).filter(TaskInstance.task_id == task_id).scalar()
diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py
index 76f5b91db4..0814b0aad1 100644
--- a/tests/www/views/test_views_variable.py
+++ b/tests/www/views/test_views_variable.py
@@ -55,7 +55,10 @@ def user_variable_reader(app):
app,
username="user_variable_reader",
role_name="role_variable_reader",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)],
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE),
+ ],
)
@@ -190,6 +193,16 @@ def test_import_variables_anon(session, app):
check_content_in_response("Sign In", resp)
+def test_import_variables_access_denied(session, app, viewer_client):
+ content = '{"str_key": "str_value}'
+ bytes_content = BytesIO(bytes(content, encoding="utf-8"))
+
+ resp = viewer_client.post(
+ "/variable/varimport", data={"file": (bytes_content, "test.json")}, follow_redirects=True
+ )
+ check_content_in_response("Access is Denied", resp)
+
+
def test_import_variables_form_shown(app, admin_client):
resp = admin_client.get("/variable/list/")
check_content_in_response("Import Variables", resp)
@@ -242,3 +255,13 @@ def test_action_muldelete(session, admin_client, variable):
)
assert resp.status_code == 200
assert session.query(Variable).filter(Variable.id == var_id).count() == 0
+
+
+def test_action_muldelete_access_denied(session, client_variable_reader, variable):
+ var_id = variable.id
+ resp = client_variable_reader.post(
+ "/variable/action_post",
+ data={"action": "muldelete", "rowid": [var_id]},
+ follow_redirects=True,
+ )
+ check_content_in_response("Access is Denied", resp)