You are viewing a plain-text version of this content. The link to the canonical version is not included in this plain-text copy.
Posted to commits@superset.apache.org by jo...@apache.org on 2020/06/05 15:44:40 UTC
[incubator-superset] branch master updated: style(mypy): Enforcing typing for superset.views (#9939)
This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 63e0188 style(mypy): Enforcing typing for superset.views (#9939)
63e0188 is described below
commit 63e0188f45134c25267d183f5d7391577f9a6d63
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Fri Jun 5 08:44:11 2020 -0700
style(mypy): Enforcing typing for superset.views (#9939)
Co-authored-by: John Bodley <jo...@airbnb.com>
---
setup.cfg | 2 +-
superset/connectors/base/models.py | 18 +-
superset/connectors/druid/models.py | 2 +-
superset/connectors/sqla/models.py | 2 +-
superset/sql_validators/base.py | 6 +-
superset/sql_validators/presto_db.py | 2 +-
superset/tasks/schedules.py | 6 +-
superset/views/annotations.py | 8 +-
superset/views/api.py | 7 +-
superset/views/base.py | 56 +++--
superset/views/base_api.py | 34 ++-
superset/views/base_schemas.py | 25 +-
superset/views/core.py | 443 ++++++++++++++++++-----------------
superset/views/datasource.py | 11 +-
superset/views/filters.py | 7 +-
superset/views/log/__init__.py | 4 +-
superset/views/log/api.py | 2 +-
superset/views/schedules.py | 35 ++-
superset/views/sql_lab.py | 43 ++--
superset/views/tags.py | 23 +-
superset/views/utils.py | 26 +-
superset/viz.py | 2 +-
superset/viz_sip38.py | 16 +-
23 files changed, 440 insertions(+), 340 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index fc94a24..81c7ed2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
-[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset [...]
+[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset [...]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 8ead670..0533aa1 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -62,14 +62,26 @@ class BaseDatasource(
# ---------------------------------------------------------------
__tablename__: Optional[str] = None # {connector_name}_datasource
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
- column_class: Optional[Type] = None # link to derivative of BaseColumn
- metric_class: Optional[Type] = None # link to derivative of BaseMetric
+
+ @property
+ def column_class(self) -> Type:
+ # link to derivative of BaseColumn
+ raise NotImplementedError()
+
+ @property
+ def metric_class(self) -> Type:
+ # link to derivative of BaseMetric
+ raise NotImplementedError()
+
owner_class: Optional[User] = None
# Used to do code highlighting when displaying the query in the UI
query_language: Optional[str] = None
- name = None # can be a Column or a property pointing to one
+ @property
+ def name(self) -> str:
+ # can be a Column or a property pointing to one
+ raise NotImplementedError()
# ---------------------------------------------------------------
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index b0a3332..50f1637 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -548,7 +548,7 @@ class DruidDatasource(Model, BaseDatasource):
return [c.column_name for c in self.columns if c.is_numeric]
@property
- def name(self) -> str: # type: ignore
+ def name(self) -> str:
return self.datasource_name
@property
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index b413ebd..0e91bd2 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -531,7 +531,7 @@ class SqlaTable(Model, BaseDatasource):
return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self)
@property
- def name(self) -> str: # type: ignore
+ def name(self) -> str:
if not self.schema:
return self.table_name
return "{}.{}".format(self.schema, self.table_name)
diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py
index beed47c..c477568 100644
--- a/superset/sql_validators/base.py
+++ b/superset/sql_validators/base.py
@@ -19,6 +19,8 @@
from typing import Any, Dict, List, Optional
+from superset.models.core import Database
+
class SQLValidationAnnotation:
"""Represents a single annotation (error/warning) in an SQL querytext"""
@@ -35,7 +37,7 @@ class SQLValidationAnnotation:
self.start_column = start_column
self.end_column = end_column
- def to_dict(self) -> Dict:
+ def to_dict(self) -> Dict[str, Any]:
"""Return a dictionary representation of this annotation"""
return {
"line_number": self.line_number,
@@ -53,7 +55,7 @@ class BaseSQLValidator:
@classmethod
def validate(
- cls, sql: str, schema: str, database: Any
+ cls, sql: str, schema: Optional[str], database: Database
) -> List[SQLValidationAnnotation]:
"""Check that the given SQL querystring is valid for the given engine"""
raise NotImplementedError
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 42e7cff..6c5bb30 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -143,7 +143,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate(
- cls, sql: str, schema: str, database: Any
+ cls, sql: str, schema: Optional[str], database: Database
) -> List[SQLValidationAnnotation]:
"""
Presto supports query-validation queries by running them with a
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index 0a356ea..2a5733e 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -225,9 +225,11 @@ def deliver_dashboard(schedule: DashboardEmailSchedule) -> None:
"""
dashboard = schedule.dashboard
- dashboard_url = _get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
+ dashboard_url = _get_url_path(
+ "Superset.dashboard", dashboard_id_or_slug=dashboard.id
+ )
dashboard_url_user_friendly = _get_url_path(
- "Superset.dashboard", user_friendly=True, dashboard_id=dashboard.id
+ "Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id
)
# Create a driver, fetch the page, wait for the page to render
diff --git a/superset/views/annotations.py b/superset/views/annotations.py
index e29883d..8442877 100644
--- a/superset/views/annotations.py
+++ b/superset/views/annotations.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Any, Dict
+
from flask_appbuilder import CompactCRUDMixin
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _
@@ -30,7 +32,7 @@ class StartEndDttmValidator: # pylint: disable=too-few-public-methods
Validates dttm fields.
"""
- def __call__(self, form, field):
+ def __call__(self, form: Dict[str, Any], field: Any) -> None:
if not form["start_dttm"].data and not form["end_dttm"].data:
raise StopValidation(_("annotation start time or end time is required."))
elif (
@@ -82,13 +84,13 @@ class AnnotationModelView(
validators_columns = {"start_dttm": [StartEndDttmValidator()]}
- def pre_add(self, item):
+ def pre_add(self, item: "AnnotationModelView") -> None:
if not item.start_dttm:
item.start_dttm = item.end_dttm
elif not item.end_dttm:
item.end_dttm = item.start_dttm
- def pre_update(self, item):
+ def pre_update(self, item: "AnnotationModelView") -> None:
self.pre_add(item)
diff --git a/superset/views/api.py b/superset/views/api.py
index e82aa86..d370598 100644
--- a/superset/views/api.py
+++ b/superset/views/api.py
@@ -24,6 +24,7 @@ from superset import db, event_logger, security_manager
from superset.common.query_context import QueryContext
from superset.legacy import update_time_range
from superset.models.slice import Slice
+from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.views.base import api, BaseSupersetView, handle_api_exception
@@ -34,13 +35,13 @@ class Api(BaseSupersetView):
@handle_api_exception
@has_access_api
@expose("/v1/query/", methods=["POST"])
- def query(self):
+ def query(self) -> FlaskResponse:
"""
Takes a query_obj constructed in the client and returns payload data response
for the given query_obj.
params: query_context: json_blob
"""
- query_context = QueryContext(**json.loads(request.form.get("query_context")))
+ query_context = QueryContext(**json.loads(request.form["query_context"]))
security_manager.assert_query_context_permission(query_context)
payload_json = query_context.get_payload()
return json.dumps(
@@ -52,7 +53,7 @@ class Api(BaseSupersetView):
@handle_api_exception
@has_access_api
@expose("/v1/form_data/", methods=["GET"])
- def query_form_data(self):
+ def query_form_data(self) -> FlaskResponse:
"""
Get the formdata stored in the database for existing slice.
params: slice_id: integer
diff --git a/superset/views/base.py b/superset/views/base.py
index 5821619..1238218 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -18,13 +18,13 @@ import functools
import logging
import traceback
from datetime import datetime
-from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union
import dataclasses
import simplejson as json
import yaml
from flask import abort, flash, g, get_flashed_messages, redirect, Response, session
-from flask_appbuilder import BaseView, ModelView
+from flask_appbuilder import BaseView, Model, ModelView
from flask_appbuilder.actions import action
from flask_appbuilder.forms import DynamicForm
from flask_appbuilder.models.sqla.filters import BaseFilter
@@ -33,7 +33,9 @@ from flask_appbuilder.widgets import ListWidget
from flask_babel import get_locale, gettext as __, lazy_gettext as _
from flask_wtf.form import FlaskForm
from sqlalchemy import or_
+from sqlalchemy.orm import Query
from werkzeug.exceptions import HTTPException
+from wtforms import Form
from wtforms.fields.core import Field, UnboundField
from superset import (
@@ -47,6 +49,7 @@ from superset import (
from superset.connectors.sqla import models
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
+from superset.models.helpers import ImportMixin
from superset.translations.utils import get_language_pack
from superset.typing import FlaskResponse
from superset.utils import core as utils
@@ -93,7 +96,7 @@ def json_error_response(
status: int = 500,
payload: Optional[Dict[str, Any]] = None,
link: Optional[str] = None,
-) -> Response:
+) -> FlaskResponse:
if not payload:
payload = {"error": "{}".format(msg)}
if link:
@@ -110,7 +113,7 @@ def json_errors_response(
errors: List[SupersetError],
status: int = 500,
payload: Optional[Dict[str, Any]] = None,
-) -> Response:
+) -> FlaskResponse:
if not payload:
payload = {}
@@ -122,11 +125,11 @@ def json_errors_response(
)
-def json_success(json_msg: str, status: int = 200) -> Response:
+def json_success(json_msg: str, status: int = 200) -> FlaskResponse:
return Response(json_msg, status=status, mimetype="application/json")
-def data_payload_response(payload_json: str, has_error: bool = False) -> Response:
+def data_payload_response(payload_json: str, has_error: bool = False) -> FlaskResponse:
status = 400 if has_error else 200
return json_success(payload_json, status=status)
@@ -140,13 +143,13 @@ def generate_download_headers(
return headers
-def api(f):
+def api(f: Callable) -> Callable:
"""
A decorator to label an endpoint as an API. Catches uncaught exceptions and
return the response in the JSON format
"""
- def wraps(self, *args, **kwargs):
+ def wraps(self: "BaseSupersetView", *args: Any, **kwargs: Any) -> FlaskResponse:
try:
return f(self, *args, **kwargs)
except Exception as ex: # pylint: disable=broad-except
@@ -156,14 +159,16 @@ def api(f):
return functools.update_wrapper(wraps, f)
-def handle_api_exception(f):
+def handle_api_exception(
+ f: Callable[..., FlaskResponse]
+) -> Callable[..., FlaskResponse]:
"""
A decorator to catch superset exceptions. Use it after the @api decorator above
so superset exception handler is triggered before the handler for generic
exceptions.
"""
- def wraps(self, *args, **kwargs):
+ def wraps(self: "BaseSupersetView", *args: Any, **kwargs: Any) -> FlaskResponse:
try:
return f(self, *args, **kwargs)
except SupersetSecurityException as ex:
@@ -179,7 +184,7 @@ def handle_api_exception(f):
except HTTPException as ex:
logger.exception(ex)
return json_error_response(
- utils.error_msg_from_exception(ex), status=ex.code
+ utils.error_msg_from_exception(ex), status=cast(int, ex.code)
)
except Exception as ex: # pylint: disable=broad-except
logger.exception(ex)
@@ -233,7 +238,9 @@ def get_user_roles() -> List[Role]:
class BaseSupersetView(BaseView):
@staticmethod
- def json_response(obj, status=200) -> Response: # pylint: disable=no-self-use
+ def json_response(
+ obj: Any, status: int = 200
+ ) -> FlaskResponse: # pylint: disable=no-self-use
return Response(
json.dumps(obj, default=utils.json_int_dttm_ser, ignore_nan=True),
status=status,
@@ -241,7 +248,7 @@ class BaseSupersetView(BaseView):
)
-def menu_data():
+def menu_data() -> Dict[str, Any]:
menu = appbuilder.menu.get_data()
root_path = "#"
logo_target_path = ""
@@ -290,7 +297,7 @@ def menu_data():
}
-def common_bootstrap_payload():
+def common_bootstrap_payload() -> Dict[str, Any]:
"""Common data always sent to the client"""
messages = get_flashed_messages(with_categories=True)
locale = str(get_locale())
@@ -335,7 +342,7 @@ class ListWidgetWithCheckboxes(ListWidget): # pylint: disable=too-few-public-me
template = "superset/fab_overrides/list_with_checkboxes.html"
-def validate_json(_form, field):
+def validate_json(form: Form, field: Field) -> None: # pylint: disable=unused-argument
try:
json.loads(field.data)
except Exception as ex:
@@ -352,24 +359,23 @@ class YamlExportMixin: # pylint: disable=too-few-public-methods
yaml_dict_key: Optional[str] = None
@action("yaml_export", __("Export to YAML"), __("Export to YAML?"), "fa-download")
- def yaml_export(self, items):
+ def yaml_export(
+ self, items: Union[ImportMixin, List[ImportMixin]]
+ ) -> FlaskResponse:
if not isinstance(items, list):
items = [items]
data = [t.export_to_dict() for t in items]
- if self.yaml_dict_key:
- data = {self.yaml_dict_key: data}
+
return Response(
- yaml.safe_dump(data),
+ yaml.safe_dump({self.yaml_dict_key: data} if self.yaml_dict_key else data),
headers=generate_download_headers("yaml"),
mimetype="application/text",
)
class DeleteMixin: # pylint: disable=too-few-public-methods
- def _delete(
- self: Union[BaseView, "DeleteMixin", "DruidClusterModelView"], primary_key: int,
- ) -> None:
+ def _delete(self: BaseView, primary_key: int,) -> None:
"""
Delete function logic, override to implement diferent logic
deletes the record with primary_key = primary_key
@@ -411,7 +417,7 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
@action(
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False
)
- def muldelete(self, items):
+ def muldelete(self: BaseView, items: List[Model]) -> FlaskResponse:
if not items:
abort(404)
for item in items:
@@ -426,7 +432,7 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
class DatasourceFilter(BaseFilter): # pylint: disable=too-few-public-methods
- def apply(self, query, value):
+ def apply(self, query: Query, value: Any) -> Query:
if security_manager.all_datasource_access():
return query
datasource_perms = security_manager.user_view_menu_names("datasource_access")
@@ -497,7 +503,7 @@ def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
def bind_field(
- _, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any]
+ _: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any]
) -> Field:
"""
Customize how fields are bound by stripping all whitespace.
diff --git a/superset/views/base_api.py b/superset/views/base_api.py
index 5675506..3d40c33 100644
--- a/superset/views/base_api.py
+++ b/superset/views/base_api.py
@@ -16,17 +16,18 @@
# under the License.
import functools
import logging
-from typing import Any, cast, Dict, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from apispec import APISpec
-from flask import Response
-from flask_appbuilder import ModelRestApi
+from flask import Blueprint, Response
+from flask_appbuilder import AppBuilder, Model, ModelRestApi
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.filters import BaseFilter, Filters
from flask_appbuilder.models.sqla.filters import FilterStartsWith
from marshmallow import Schema
from superset.stats_logger import BaseStatsLogger
+from superset.typing import FlaskResponse
from superset.utils.core import time_function
logger = logging.getLogger(__name__)
@@ -40,12 +41,12 @@ get_related_schema = {
}
-def statsd_metrics(f):
+def statsd_metrics(f: Callable) -> Callable:
"""
Handle sending all statsd metrics from the REST API
"""
- def wraps(self, *args: Any, **kwargs: Any) -> Response:
+ def wraps(self: "BaseSupersetModelRestApi", *args: Any, **kwargs: Any) -> Response:
duration, response = time_function(f, self, *args, **kwargs)
self.send_stats_metrics(response, f.__name__, duration)
return response
@@ -116,6 +117,11 @@ class BaseSupersetModelRestApi(ModelRestApi):
Add extra schemas to the OpenAPI component schemas section
""" # pylint: disable=pointless-string-statement
+ add_columns: List[str]
+ edit_columns: List[str]
+ list_columns: List[str]
+ show_columns: List[str]
+
def __init__(self) -> None:
super().__init__()
self.stats_logger = BaseStatsLogger()
@@ -128,11 +134,13 @@ class BaseSupersetModelRestApi(ModelRestApi):
)
super().add_apispec_components(api_spec)
- def create_blueprint(self, appbuilder, *args, **kwargs):
+ def create_blueprint(
+ self, appbuilder: AppBuilder, *args: Any, **kwargs: Any
+ ) -> Blueprint:
self.stats_logger = self.appbuilder.get_app.config["STATS_LOGGER"]
return super().create_blueprint(appbuilder, *args, **kwargs)
- def _init_properties(self):
+ def _init_properties(self) -> None:
model_id = self.datamodel.get_pk_name()
if self.list_columns is None and not self.list_model_schema:
self.list_columns = [model_id]
@@ -144,7 +152,9 @@ class BaseSupersetModelRestApi(ModelRestApi):
self.add_columns = [model_id]
super()._init_properties()
- def _get_related_filter(self, datamodel, column_name: str, value: str) -> Filters:
+ def _get_related_filter(
+ self, datamodel: Model, column_name: str, value: str
+ ) -> Filters:
filter_field = self.related_field_filters.get(column_name)
if isinstance(filter_field, str):
filter_field = RelatedFieldFilter(cast(str, filter_field), FilterStartsWith)
@@ -198,7 +208,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
if time_delta:
self.timing_stats("time", key, time_delta)
- def info_headless(self, **kwargs) -> Response:
+ def info_headless(self, **kwargs: Any) -> Response:
"""
Add statsd metrics to builtin FAB _info endpoint
"""
@@ -206,7 +216,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
self.send_stats_metrics(response, self.info.__name__, duration)
return response
- def get_headless(self, pk, **kwargs) -> Response:
+ def get_headless(self, pk: int, **kwargs: Any) -> Response:
"""
Add statsd metrics to builtin FAB GET endpoint
"""
@@ -214,7 +224,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
self.send_stats_metrics(response, self.get.__name__, duration)
return response
- def get_list_headless(self, **kwargs) -> Response:
+ def get_list_headless(self, **kwargs: Any) -> Response:
"""
Add statsd metrics to builtin FAB GET list endpoint
"""
@@ -227,7 +237,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
@safe
@statsd_metrics
@rison(get_related_schema)
- def related(self, column_name: str, **kwargs):
+ def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:
"""Get related fields data
---
get:
diff --git a/superset/views/base_schemas.py b/superset/views/base_schemas.py
index e4795c5..a4436dd 100644
--- a/superset/views/base_schemas.py
+++ b/superset/views/base_schemas.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union
from flask import current_app, g
from flask_appbuilder import Model
@@ -22,7 +22,7 @@ from marshmallow import post_load, pre_load, Schema, ValidationError
from sqlalchemy.orm.exc import NoResultFound
-def validate_owner(value):
+def validate_owner(value: int) -> None:
try:
(
current_app.appbuilder.get_session.query(
@@ -44,18 +44,25 @@ class BaseSupersetSchema(Schema):
__class_model__: Model = None
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any) -> None:
self.instance: Optional[Model] = None
super().__init__(**kwargs)
- def load(
- self, data, many=None, partial=None, instance: Model = None, **kwargs
- ): # pylint: disable=arguments-differ
+ def load( # pylint: disable=arguments-differ
+ self,
+ data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]],
+ many: Optional[bool] = None,
+ partial: Optional[Union[bool, Sequence[str], Set[str]]] = None,
+ instance: Optional[Model] = None,
+ **kwargs: Any,
+ ) -> Any:
self.instance = instance
return super().load(data, many=many, partial=partial, **kwargs)
@post_load
- def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
+ def make_object(
+ self, data: Dict[Any, Any], discard: Optional[List[str]] = None
+ ) -> Model:
"""
Creates a Model object from POST or PUT requests. PUT will use self.instance
previously fetched from the endpoint handler
@@ -92,13 +99,13 @@ class BaseOwnedSchema(BaseSupersetSchema):
return instance
@pre_load
- def pre_load(self, data: Dict):
+ def pre_load(self, data: Dict[Any, Any]) -> None:
# if PUT request don't set owners to empty list
if not self.instance:
data[self.owners_field_name] = data.get(self.owners_field_name, [])
@staticmethod
- def set_owners(instance: Model, owners: List[int]):
+ def set_owners(instance: Model, owners: List[int]) -> None:
owner_objs = list()
if g.user.id not in owners:
owners.append(g.user.id)
diff --git a/superset/views/core.py b/superset/views/core.py
index 3d89334..c561222 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -33,6 +33,7 @@ from flask_appbuilder import expose
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import has_access, has_access_api
from flask_appbuilder.security.sqla import models as ab_models
+from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import and_, Integer, or_, select
from sqlalchemy.engine.url import make_url
@@ -64,7 +65,12 @@ from superset import (
viz,
)
from superset.connectors.connector_registry import ConnectorRegistry
-from superset.connectors.sqla.models import AnnotationDatasource
+from superset.connectors.sqla.models import (
+ AnnotationDatasource,
+ SqlaTable,
+ SqlMetric,
+ TableColumn,
+)
from superset.constants import RouteMethod
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
@@ -88,12 +94,14 @@ from superset.security.analytics_db_safety import (
)
from superset.sql_parse import ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
+from superset.typing import FlaskResponse
from superset.utils import core as utils, dashboard_import_export
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
from superset.utils.dates import now_as_float
from superset.utils.decorators import etag_cache, stats_timing
from superset.views.database.filters import DatabaseFilter
from superset.views.utils import get_dashboard_extra_filters
+from superset.viz import BaseViz
from .base import (
api,
@@ -157,7 +165,7 @@ if not config["ENABLE_JAVASCRIPT_CONTROLS"]:
FORM_DATA_KEY_BLACKLIST = ["js_tooltip", "js_onclick_href", "js_data_mutator"]
-def get_database_access_error_msg(database_name):
+def get_database_access_error_msg(database_name: str) -> str:
return __(
"This view requires the database %(name)s or "
"`all_datasource_access` permission",
@@ -165,13 +173,15 @@ def get_database_access_error_msg(database_name):
)
-def is_owner(obj, user):
+def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool:
""" Check if user is owner of the slice """
return obj and user in obj.owners
def check_datasource_perms(
- self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
+ self: "Superset",
+ datasource_type: Optional[str] = None,
+ datasource_id: Optional[int] = None,
) -> None:
"""
Check if user can access a cached response from explore_json.
@@ -218,7 +228,7 @@ def check_datasource_perms(
security_manager.assert_viz_permission(viz_obj)
-def check_slice_perms(self, slice_id):
+def check_slice_perms(self: "Superset", slice_id: int) -> None:
"""
Check if user can access a cached response from slice_json.
@@ -228,19 +238,20 @@ def check_slice_perms(self, slice_id):
form_data, slc = get_form_data(slice_id, use_slice_data=True)
- viz_obj = get_viz(
- datasource_type=slc.datasource.type,
- datasource_id=slc.datasource.id,
- form_data=form_data,
- force=False,
- )
+ if slc:
+ viz_obj = get_viz(
+ datasource_type=slc.datasource.type,
+ datasource_id=slc.datasource.id,
+ form_data=form_data,
+ force=False,
+ )
- security_manager.assert_viz_permission(viz_obj)
+ security_manager.assert_viz_permission(viz_obj)
def _deserialize_results_payload(
- payload: Union[bytes, str], query, use_msgpack: Optional[bool] = False
-) -> dict:
+ payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False
+) -> Dict[Any, Any]:
logger.debug(f"Deserializing from msgpack: {use_msgpack}")
if use_msgpack:
with stats_timing(
@@ -305,19 +316,19 @@ class AccessRequestsModelView(SupersetModelView, DeleteMixin):
@talisman(force_https=False)
@app.route("/health")
-def health():
+def health() -> FlaskResponse:
return "OK"
@talisman(force_https=False)
@app.route("/healthcheck")
-def healthcheck():
+def healthcheck() -> FlaskResponse:
return "OK"
@talisman(force_https=False)
@app.route("/ping")
-def ping():
+def ping() -> FlaskResponse:
return "OK"
@@ -328,26 +339,26 @@ class KV(BaseSupersetView):
@event_logger.log_this
@has_access_api
@expose("/store/", methods=["POST"])
- def store(self):
+ def store(self) -> FlaskResponse:
try:
value = request.form.get("data")
obj = models.KeyValue(value=value)
db.session.add(obj)
db.session.commit()
except Exception as ex:
- return json_error_response(ex)
+ return json_error_response(utils.error_msg_from_exception(ex))
return Response(json.dumps({"id": obj.id}), status=200)
@event_logger.log_this
@has_access_api
- @expose("/<key_id>/", methods=["GET"])
- def get_value(self, key_id):
+ @expose("/<int:key_id>/", methods=["GET"])
+ def get_value(self, key_id: int) -> FlaskResponse:
try:
kv = db.session.query(models.KeyValue).filter_by(id=key_id).scalar()
if not kv:
return Response(status=404, content_type="text/plain")
except Exception as ex:
- return json_error_response(ex)
+ return json_error_response(utils.error_msg_from_exception(ex))
return Response(kv.value, status=200, content_type="text/plain")
@@ -356,8 +367,8 @@ class R(BaseSupersetView):
"""used for short urls"""
@event_logger.log_this
- @expose("/<url_id>")
- def index(self, url_id):
+ @expose("/<int:url_id>")
+ def index(self, url_id: int) -> FlaskResponse:
url = db.session.query(models.Url).get(url_id)
if url and url.url:
explore_url = "//superset/explore/?"
@@ -373,7 +384,7 @@ class R(BaseSupersetView):
@event_logger.log_this
@has_access_api
@expose("/shortner/", methods=["POST"])
- def shortner(self):
+ def shortner(self) -> FlaskResponse:
url = request.form.get("data")
obj = models.Url(url=url)
db.session.add(obj)
@@ -393,15 +404,21 @@ class Superset(BaseSupersetView):
@has_access_api
@expose("/datasources/")
- def datasources(self):
- datasources = ConnectorRegistry.get_all_datasources(db.session)
- datasources = [o.short_data for o in datasources if o.short_data.get("name")]
- datasources = sorted(datasources, key=lambda o: o["name"])
- return self.json_response(datasources)
+ def datasources(self) -> FlaskResponse:
+ return self.json_response(
+ sorted(
+ [
+ datasource.short_data
+ for datasource in ConnectorRegistry.get_all_datasources(db.session)
+ if datasource.short_data.get("name")
+ ],
+ key=lambda datasource: datasource["name"],
+ )
+ )
@has_access_api
@expose("/override_role_permissions/", methods=["POST"])
- def override_role_permissions(self):
+ def override_role_permissions(self) -> FlaskResponse:
"""Updates the role with the give datasource permissions.
Permissions not in the request will be revoked. This endpoint should
@@ -454,7 +471,7 @@ class Superset(BaseSupersetView):
@event_logger.log_this
@has_access
@expose("/request_access/")
- def request_access(self):
+ def request_access(self) -> FlaskResponse:
datasources = set()
dashboard_id = request.args.get("dashboard_id")
if dashboard_id:
@@ -462,7 +479,7 @@ class Superset(BaseSupersetView):
datasources |= dash.datasources
datasource_id = request.args.get("datasource_id")
datasource_type = request.args.get("datasource_type")
- if datasource_id:
+ if datasource_id and datasource_type:
ds_class = ConnectorRegistry.sources.get(datasource_type)
datasource = (
db.session.query(ds_class).filter_by(id=int(datasource_id)).one()
@@ -497,8 +514,8 @@ class Superset(BaseSupersetView):
@event_logger.log_this
@has_access
@expose("/approve")
- def approve(self):
- def clean_fulfilled_requests(session):
+ def approve(self) -> FlaskResponse:
+ def clean_fulfilled_requests(session: Session) -> None:
for r in session.query(DAR).all():
datasource = ConnectorRegistry.get_datasource(
r.datasource_type, r.datasource_id, session
@@ -508,8 +525,8 @@ class Superset(BaseSupersetView):
session.delete(r)
session.commit()
- datasource_type = request.args.get("datasource_type")
- datasource_id = request.args.get("datasource_id")
+ datasource_type = request.args["datasource_type"]
+ datasource_id = request.args["datasource_id"]
created_by_username = request.args.get("created_by")
role_to_grant = request.args.get("role_to_grant")
role_to_extend = request.args.get("role_to_extend")
@@ -598,8 +615,8 @@ class Superset(BaseSupersetView):
return redirect("/accessrequestsmodelview/list/")
@has_access
- @expose("/slice/<slice_id>/")
- def slice(self, slice_id):
+ @expose("/slice/<int:slice_id>/")
+ def slice(self, slice_id: int) -> FlaskResponse:
form_data, slc = get_form_data(slice_id, use_slice_data=True)
if not slc:
abort(404)
@@ -611,15 +628,16 @@ class Superset(BaseSupersetView):
endpoint += f"&{param}=true"
return redirect(endpoint)
- def get_query_string_response(self, viz_obj):
+ def get_query_string_response(self, viz_obj: BaseViz) -> FlaskResponse:
query = None
try:
query_obj = viz_obj.query_obj()
if query_obj:
query = viz_obj.datasource.get_query_str(query_obj)
except Exception as ex:
- logger.exception(ex)
- return json_error_response(ex)
+ err_msg = utils.error_msg_from_exception(ex)
+ logger.exception(err_msg)
+ return json_error_response(err_msg)
if not query:
query = "No query."
@@ -628,15 +646,17 @@ class Superset(BaseSupersetView):
{"query": query, "language": viz_obj.datasource.query_language}
)
- def get_raw_results(self, viz_obj):
+ def get_raw_results(self, viz_obj: BaseViz) -> FlaskResponse:
return self.json_response(
{"data": viz_obj.get_df_payload()["df"].to_dict("records")}
)
- def get_samples(self, viz_obj):
+ def get_samples(self, viz_obj: BaseViz) -> FlaskResponse:
return self.json_response({"data": viz_obj.get_samples()})
- def generate_json(self, viz_obj, response_type: Optional[str] = None) -> Response:
+ def generate_json(
+ self, viz_obj: BaseViz, response_type: Optional[str] = None
+ ) -> FlaskResponse:
if response_type == utils.ChartDataResultFormat.CSV:
return CsvResponse(
viz_obj.get_csv(),
@@ -660,16 +680,16 @@ class Superset(BaseSupersetView):
@event_logger.log_this
@api
@has_access_api
- @expose("/slice_json/<slice_id>")
+ @expose("/slice_json/<int:slice_id>")
@etag_cache(CACHE_DEFAULT_TIMEOUT, check_perms=check_slice_perms)
- def slice_json(self, slice_id):
+ def slice_json(self, slice_id: int) -> FlaskResponse:
form_data, slc = get_form_data(slice_id, use_slice_data=True)
- datasource_type = slc.datasource.type
- datasource_id = slc.datasource.id
+ if not slc:
+ return json_error_response("The slice does not exist")
try:
viz_obj = get_viz(
- datasource_type=datasource_type,
- datasource_id=datasource_id,
+ datasource_type=slc.datasource.type,
+ datasource_id=slc.datasource.id,
form_data=form_data,
force=False,
)
@@ -680,8 +700,8 @@ class Superset(BaseSupersetView):
@event_logger.log_this
@api
@has_access_api
- @expose("/annotation_json/<layer_id>")
- def annotation_json(self, layer_id):
+ @expose("/annotation_json/<int:layer_id>")
+ def annotation_json(self, layer_id: int) -> FlaskResponse:
form_data = get_form_data()[0]
form_data["layer_id"] = layer_id
form_data["filters"] = [{"col": "layer_id", "op": "==", "val": layer_id}]
@@ -714,11 +734,14 @@ class Superset(BaseSupersetView):
@has_access_api
@handle_api_exception
@expose(
- "/explore_json/<datasource_type>/<datasource_id>/", methods=EXPLORE_JSON_METHODS
+ "/explore_json/<datasource_type>/<int:datasource_id>/",
+ methods=EXPLORE_JSON_METHODS,
)
@expose("/explore_json/", methods=EXPLORE_JSON_METHODS)
@etag_cache(CACHE_DEFAULT_TIMEOUT, check_perms=check_datasource_perms)
- def explore_json(self, datasource_type=None, datasource_id=None):
+ def explore_json(
+ self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
+ ) -> FlaskResponse:
"""Serves all request that GET or POST form_data
This endpoint evolved to be the entry point of many different
@@ -729,7 +752,9 @@ class Superset(BaseSupersetView):
TODO: break into one endpoint for each return shape"""
response_type = utils.ChartDataResultFormat.JSON.value
- responses = [resp_format for resp_format in utils.ChartDataResultFormat]
+ responses: List[
+ Union[utils.ChartDataResultFormat, utils.ChartDataResultType]
+ ] = [resp_format for resp_format in utils.ChartDataResultFormat]
responses.extend([resp_type for resp_type in utils.ChartDataResultType])
for response_option in responses:
if request.args.get(response_option) == "true":
@@ -744,7 +769,7 @@ class Superset(BaseSupersetView):
)
viz_obj = get_viz(
- datasource_type=datasource_type,
+ datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
form_data=form_data,
force=request.args.get("force") == "true",
@@ -757,7 +782,7 @@ class Superset(BaseSupersetView):
@event_logger.log_this
@has_access
@expose("/import_dashboards", methods=["GET", "POST"])
- def import_dashboards(self):
+ def import_dashboards(self) -> FlaskResponse:
"""Overrides the dashboards using json instances from the file."""
f = request.files.get("file")
if request.method == "POST" and f:
@@ -788,9 +813,11 @@ class Superset(BaseSupersetView):
@event_logger.log_this
@has_access
- @expose("/explore/<datasource_type>/<datasource_id>/", methods=["GET", "POST"])
+ @expose("/explore/<datasource_type>/<int:datasource_id>/", methods=["GET", "POST"])
@expose("/explore/", methods=["GET", "POST"])
- def explore(self, datasource_type=None, datasource_id=None):
+ def explore(
+ self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
+ ) -> FlaskResponse:
user_id = g.user.get_id() if g.user else None
form_data, slc = get_form_data(use_slice_data=True)
@@ -834,7 +861,7 @@ class Superset(BaseSupersetView):
return redirect(error_redirect)
datasource = ConnectorRegistry.get_datasource(
- datasource_type, datasource_id, db.session
+ cast(str, datasource_type), datasource_id, db.session
)
if not datasource:
flash(DATASOURCE_MISSING_ERR, "danger")
@@ -859,12 +886,12 @@ class Superset(BaseSupersetView):
# slc perms
slice_add_perm = security_manager.can_access("can_add", "SliceModelView")
- slice_overwrite_perm = is_owner(slc, g.user)
+ slice_overwrite_perm = is_owner(slc, g.user) if slc else False
slice_download_perm = security_manager.can_access(
"can_download", "SliceModelView"
)
- form_data["datasource"] = str(datasource_id) + "__" + datasource_type
+ form_data["datasource"] = str(datasource_id) + "__" + cast(str, datasource_type)
# On explore, merge legacy and extra filters into the form data
utils.convert_legacy_filters_into_adhoc(form_data)
@@ -890,14 +917,16 @@ class Superset(BaseSupersetView):
)
if action in ("saveas", "overwrite"):
+ if not slc:
+ return json_error_response("The slice does not exist")
+
return self.save_or_overwrite_slice(
- request.args,
slc,
slice_add_perm,
slice_overwrite_perm,
slice_download_perm,
datasource_id,
- datasource_type,
+ cast(str, datasource_type),
datasource.name,
)
@@ -940,8 +969,10 @@ class Superset(BaseSupersetView):
@api
@handle_api_exception
@has_access_api
- @expose("/filter/<datasource_type>/<datasource_id>/<column>/")
- def filter(self, datasource_type, datasource_id, column):
+ @expose("/filter/<datasource_type>/<int:datasource_id>/<column>/")
+ def filter(
+ self, datasource_type: str, datasource_id: int, column: str
+ ) -> FlaskResponse:
"""
Endpoint to retrieve values for specified column.
@@ -965,28 +996,27 @@ class Superset(BaseSupersetView):
return json_success(payload)
@staticmethod
- def remove_extra_filters(filters):
+ def remove_extra_filters(filters: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Extra filters are ones inherited from the dashboard's temporary context
Those should not be saved when saving the chart"""
return [f for f in filters if not f.get("isExtra")]
def save_or_overwrite_slice(
self,
- args,
- slc,
- slice_add_perm,
- slice_overwrite_perm,
- slice_download_perm,
- datasource_id,
- datasource_type,
- datasource_name,
- ):
+ slc: Slice,
+ slice_add_perm: bool,
+ slice_overwrite_perm: bool,
+ slice_download_perm: bool,
+ datasource_id: int,
+ datasource_type: str,
+ datasource_name: str,
+ ) -> FlaskResponse:
"""Save or overwrite a slice"""
- slice_name = args.get("slice_name")
- action = args.get("action")
+ slice_name = request.args.get("slice_name")
+ action = request.args.get("action")
form_data = get_form_data()[0]
- if action in ("saveas"):
+ if action == "saveas":
if "slice_id" in form_data:
form_data.pop("slice_id") # don't save old slice_id
slc = Slice(owners=[g.user] if g.user else [])
@@ -1002,18 +1032,20 @@ class Superset(BaseSupersetView):
slc.datasource_id = datasource_id
slc.slice_name = slice_name
- if action in ("saveas") and slice_add_perm:
+ if action == "saveas" and slice_add_perm:
self.save_slice(slc)
elif action == "overwrite" and slice_overwrite_perm:
self.overwrite_slice(slc)
# Adding slice to a dashboard if requested
- dash = None
+ dash: Optional[Dashboard] = None
+
if request.args.get("add_to_dash") == "existing":
- dash = (
+ dash = cast(
+ Dashboard,
db.session.query(Dashboard)
- .filter_by(id=int(request.args.get("save_to_dashboard_id")))
- .one()
+ .filter_by(id=int(request.args["save_to_dashboard_id"]))
+ .one(),
)
# check edit dashboard permissions
dash_overwrite_perm = check_ownership(dash, raise_if_false=False)
@@ -1066,19 +1098,19 @@ class Superset(BaseSupersetView):
"dashboard_id": dash.id if dash else None,
}
- if request.args.get("goto_dash") == "true":
+ if dash and request.args.get("goto_dash") == "true":
response.update({"dashboard": dash.url})
return json_success(json.dumps(response))
- def save_slice(self, slc):
+ def save_slice(self, slc: Slice) -> None:
session = db.session()
msg = _("Chart [{}] has been saved").format(slc.slice_name)
session.add(slc)
session.commit()
flash(msg, "info")
- def overwrite_slice(self, slc):
+ def overwrite_slice(self, slc: Slice) -> None:
session = db.session()
session.merge(slc)
session.commit()
@@ -1087,17 +1119,16 @@ class Superset(BaseSupersetView):
@api
@has_access_api
- @expose("/schemas/<db_id>/")
- @expose("/schemas/<db_id>/<force_refresh>/")
- def schemas(self, db_id, force_refresh="false"):
+ @expose("/schemas/<int:db_id>/")
+ @expose("/schemas/<int:db_id>/<force_refresh>/")
+ def schemas(self, db_id: int, force_refresh: str = "false") -> FlaskResponse:
db_id = int(db_id)
- force_refresh = force_refresh.lower() == "true"
database = db.session.query(models.Database).get(db_id)
if database:
schemas = database.get_all_schema_names(
cache=database.schema_cache_enabled,
cache_timeout=database.schema_cache_timeout,
- force=force_refresh,
+ force=force_refresh.lower() == "true",
)
schemas = security_manager.schemas_accessible_by_user(database, schemas)
else:
@@ -1111,7 +1142,7 @@ class Superset(BaseSupersetView):
@expose("/tables/<int:db_id>/<schema>/<substr>/<force_refresh>/")
def tables(
self, db_id: int, schema: str, substr: str, force_refresh: str = "false"
- ):
+ ) -> FlaskResponse:
"""Endpoint to fetch the list of tables for given database"""
# Guarantees database filtering by security access
query = db.session.query(models.Database)
@@ -1211,11 +1242,11 @@ class Superset(BaseSupersetView):
@api
@has_access_api
- @expose("/copy_dash/<dashboard_id>/", methods=["GET", "POST"])
- def copy_dash(self, dashboard_id):
+ @expose("/copy_dash/<int:dashboard_id>/", methods=["GET", "POST"])
+ def copy_dash(self, dashboard_id: int) -> FlaskResponse:
"""Copy dashboard"""
session = db.session()
- data = json.loads(request.form.get("data"))
+ data = json.loads(request.form["data"])
dash = models.Dashboard()
original_dash = session.query(Dashboard).get(dashboard_id)
@@ -1235,12 +1266,8 @@ class Superset(BaseSupersetView):
# update chartId of layout entities
for value in data["positions"].values():
- if (
- isinstance(value, dict)
- and value.get("meta")
- and value.get("meta").get("chartId")
- ):
- old_id = value.get("meta").get("chartId")
+ if isinstance(value, dict) and value.get("meta", {}).get("chartId"):
+ old_id = value["meta"]["chartId"]
new_id = old_to_new_slice_ids[old_id]
value["meta"]["chartId"] = new_id
else:
@@ -1257,13 +1284,13 @@ class Superset(BaseSupersetView):
@api
@has_access_api
- @expose("/save_dash/<dashboard_id>/", methods=["GET", "POST"])
- def save_dash(self, dashboard_id):
+ @expose("/save_dash/<int:dashboard_id>/", methods=["GET", "POST"])
+ def save_dash(self, dashboard_id: int) -> FlaskResponse:
"""Save a dashboard's metadata"""
session = db.session()
dash = session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True)
- data = json.loads(request.form.get("data"))
+ data = json.loads(request.form["data"])
self._set_dash_metadata(dash, data)
session.merge(dash)
session.commit()
@@ -1272,8 +1299,10 @@ class Superset(BaseSupersetView):
@staticmethod
def _set_dash_metadata(
- dashboard, data, old_to_new_slice_ids: Optional[Dict[int, int]] = None
- ):
+ dashboard: Dashboard,
+ data: Dict[Any, Any],
+ old_to_new_slice_ids: Optional[Dict[int, int]] = None,
+ ) -> None:
positions = data["positions"]
# find slices in the position data
slice_ids = []
@@ -1352,10 +1381,10 @@ class Superset(BaseSupersetView):
@api
@has_access_api
- @expose("/add_slices/<dashboard_id>/", methods=["POST"])
- def add_slices(self, dashboard_id):
+ @expose("/add_slices/<int:dashboard_id>/", methods=["POST"])
+ def add_slices(self, dashboard_id: int) -> FlaskResponse:
"""Add and save slices to a dashboard"""
- data = json.loads(request.form.get("data"))
+ data = json.loads(request.form["data"])
session = db.session()
dash = session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True)
@@ -1369,7 +1398,7 @@ class Superset(BaseSupersetView):
@api
@has_access_api
@expose("/testconn", methods=["POST", "GET"])
- def testconn(self):
+ def testconn(self) -> FlaskResponse:
"""Tests a sqla connection"""
db_name = request.json.get("name")
uri = request.json.get("uri")
@@ -1443,13 +1472,13 @@ class Superset(BaseSupersetView):
@api
@has_access_api
- @expose("/recent_activity/<user_id>/", methods=["GET"])
- def recent_activity(self, user_id):
+ @expose("/recent_activity/<int:user_id>/", methods=["GET"])
+ def recent_activity(self, user_id: int) -> FlaskResponse:
"""Recent activity (actions) for a given user"""
M = models
if request.args.get("limit"):
- limit = int(request.args.get("limit"))
+ limit = int(request.args["limit"])
else:
limit = 1000
@@ -1490,7 +1519,7 @@ class Superset(BaseSupersetView):
@api
@has_access_api
@expose("/csrf_token/", methods=["GET"])
- def csrf_token(self):
+ def csrf_token(self) -> FlaskResponse:
return Response(
self.render_template("superset/csrf_token.json"), mimetype="text/json"
)
@@ -1498,7 +1527,7 @@ class Superset(BaseSupersetView):
@api
@has_access_api
@expose("/available_domains/", methods=["GET"])
- def available_domains(self):
+ def available_domains(self) -> FlaskResponse:
"""
Returns the list of available Superset Webserver domains (if any)
defined in config. This enables charts embedded in other apps to
@@ -1511,15 +1540,15 @@ class Superset(BaseSupersetView):
@api
@has_access_api
@expose("/fave_dashboards_by_username/<username>/", methods=["GET"])
- def fave_dashboards_by_username(self, username):
+ def fave_dashboards_by_username(self, username: str) -> FlaskResponse:
"""This lets us use a user's username to pull favourite dashboards"""
user = security_manager.find_user(username=username)
return self.fave_dashboards(user.get_id())
@api
@has_access_api
- @expose("/fave_dashboards/<user_id>/", methods=["GET"])
- def fave_dashboards(self, user_id):
+ @expose("/fave_dashboards/<int:user_id>/", methods=["GET"])
+ def fave_dashboards(self, user_id: int) -> FlaskResponse:
qry = (
db.session.query(Dashboard, models.FavStar.dttm)
.join(
@@ -1550,8 +1579,8 @@ class Superset(BaseSupersetView):
@api
@has_access_api
- @expose("/created_dashboards/<user_id>/", methods=["GET"])
- def created_dashboards(self, user_id):
+ @expose("/created_dashboards/<int:user_id>/", methods=["GET"])
+ def created_dashboards(self, user_id: int) -> FlaskResponse:
Dash = Dashboard
qry = (
db.session.query(Dash)
@@ -1573,8 +1602,8 @@ class Superset(BaseSupersetView):
@api
@has_access_api
@expose("/user_slices", methods=["GET"])
- @expose("/user_slices/<user_id>/", methods=["GET"])
- def user_slices(self, user_id=None):
+ @expose("/user_slices/<int:user_id>/", methods=["GET"])
+ def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse:
"""List of slices a user created, or faved"""
if not user_id:
user_id = g.user.id
@@ -1584,7 +1613,7 @@ class Superset(BaseSupersetView):
.join(
models.FavStar,
and_(
- models.FavStar.user_id == int(user_id),
+ models.FavStar.user_id == user_id,
models.FavStar.class_name == "slice",
Slice.id == models.FavStar.obj_id,
),
@@ -1615,8 +1644,8 @@ class Superset(BaseSupersetView):
@api
@has_access_api
@expose("/created_slices", methods=["GET"])
- @expose("/created_slices/<user_id>/", methods=["GET"])
- def created_slices(self, user_id=None):
+ @expose("/created_slices/<int:user_id>/", methods=["GET"])
+ def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse:
"""List of slices created by this user"""
if not user_id:
user_id = g.user.id
@@ -1640,8 +1669,8 @@ class Superset(BaseSupersetView):
@api
@has_access_api
@expose("/fave_slices", methods=["GET"])
- @expose("/fave_slices/<user_id>/", methods=["GET"])
- def fave_slices(self, user_id=None):
+ @expose("/fave_slices/<int:user_id>/", methods=["GET"])
+ def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse:
"""Favorite slices for a user"""
if not user_id:
user_id = g.user.id
@@ -1650,7 +1679,7 @@ class Superset(BaseSupersetView):
.join(
models.FavStar,
and_(
- models.FavStar.user_id == int(user_id),
+ models.FavStar.user_id == user_id,
models.FavStar.class_name == "slice",
Slice.id == models.FavStar.obj_id,
),
@@ -1677,12 +1706,11 @@ class Superset(BaseSupersetView):
@api
@has_access_api
@expose("/warm_up_cache/", methods=["GET"])
- def warm_up_cache(self):
+ def warm_up_cache(self) -> FlaskResponse:
"""Warms up the cache for the slice or table.
Note for slices a force refresh occurs.
"""
- slices = None
session = db.session()
slice_id = request.args.get("slice_id")
dashboard_id = request.args.get("dashboard_id")
@@ -1704,7 +1732,6 @@ class Superset(BaseSupersetView):
__("Chart %(id)s not found", id=slice_id), status=404
)
elif table_name and db_name:
- SqlaTable = ConnectorRegistry.sources["table"]
table = (
session.query(SqlaTable)
.join(models.Database)
@@ -1761,8 +1788,8 @@ class Superset(BaseSupersetView):
return json_success(json.dumps(result))
@has_access_api
- @expose("/favstar/<class_name>/<obj_id>/<action>/")
- def favstar(self, class_name, obj_id, action):
+ @expose("/favstar/<class_name>/<int:obj_id>/<action>/")
+ def favstar(self, class_name: str, obj_id: int, action: str) -> FlaskResponse:
"""Toggle favorite stars on Slices and Dashboard"""
session = db.session()
FavStar = models.FavStar
@@ -1793,8 +1820,8 @@ class Superset(BaseSupersetView):
@api
@has_access_api
- @expose("/dashboard/<dashboard_id>/published/", methods=("GET", "POST"))
- def publish(self, dashboard_id):
+ @expose("/dashboard/<int:dashboard_id>/published/", methods=("GET", "POST"))
+ def publish(self, dashboard_id: int) -> FlaskResponse:
"""Gets and toggles published status on dashboards"""
logger.warning(
"This API endpoint is deprecated and will be removed in version 1.0.0"
@@ -1827,15 +1854,15 @@ class Superset(BaseSupersetView):
return json_success(json.dumps({"published": dash.published}))
@has_access
- @expose("/dashboard/<dashboard_id>/")
- def dashboard(self, dashboard_id):
+ @expose("/dashboard/<dashboard_id_or_slug>/")
+ def dashboard(self, dashboard_id_or_slug: str) -> FlaskResponse:
"""Server side rendering for a dashboard"""
session = db.session()
qry = session.query(Dashboard)
- if dashboard_id.isdigit():
- qry = qry.filter_by(id=int(dashboard_id))
+ if dashboard_id_or_slug.isdigit():
+ qry = qry.filter_by(id=int(dashboard_id_or_slug))
else:
- qry = qry.filter_by(slug=dashboard_id)
+ qry = qry.filter_by(slug=dashboard_id_or_slug)
dash = qry.one_or_none()
if not dash:
@@ -1885,7 +1912,7 @@ class Superset(BaseSupersetView):
# Hack to log the dashboard_id properly, even when getting a slug
@event_logger.log_this
- def dashboard(**kwargs):
+ def dashboard(**kwargs: Any) -> None:
pass
dashboard(
@@ -1939,13 +1966,13 @@ class Superset(BaseSupersetView):
@api
@event_logger.log_this
@expose("/log/", methods=["POST"])
- def log(self):
+ def log(self) -> FlaskResponse:
return Response(status=200)
@has_access
@expose("/sync_druid/", methods=["POST"])
@event_logger.log_this
- def sync_druid_source(self):
+ def sync_druid_source(self) -> FlaskResponse:
"""Syncs the druid datasource in main db with the provided config.
The endpoint takes 3 arguments:
@@ -1996,14 +2023,15 @@ class Superset(BaseSupersetView):
try:
DruidDatasource.sync_to_db_from_config(druid_config, user, cluster)
except Exception as ex:
- logger.exception(utils.error_msg_from_exception(ex))
- return json_error_response(utils.error_msg_from_exception(ex))
+ err_msg = utils.error_msg_from_exception(ex)
+ logger.exception(err_msg)
+ return json_error_response(err_msg)
return Response(status=201)
@has_access
@expose("/get_or_create_table/", methods=["POST"])
@event_logger.log_this
- def sqllab_table_viz(self):
+ def sqllab_table_viz(self) -> FlaskResponse:
""" Gets or creates a table object with attributes passed to the API.
It expects the json with params:
@@ -2013,10 +2041,9 @@ class Superset(BaseSupersetView):
* templateParams - params for the Jinja templating syntax, optional
:return: Response
"""
- SqlaTable = ConnectorRegistry.sources["table"]
- data = json.loads(request.form.get("data"))
- table_name = data.get("datasourceName")
- database_id = data.get("dbId")
+ data = json.loads(request.form["data"])
+ table_name = data["datasourceName"]
+ database_id = data["dbId"]
table = (
db.session.query(SqlaTable)
.filter_by(database_id=database_id, table_name=table_name)
@@ -2045,11 +2072,10 @@ class Superset(BaseSupersetView):
@has_access
@expose("/sqllab_viz/", methods=["POST"])
@event_logger.log_this
- def sqllab_viz(self):
- SqlaTable = ConnectorRegistry.sources["table"]
- data = json.loads(request.form.get("data"))
- table_name = data.get("datasourceName")
- database_id = data.get("dbId")
+ def sqllab_viz(self) -> FlaskResponse:
+ data = json.loads(request.form["data"])
+ table_name = data["datasourceName"]
+ database_id = data["dbId"]
table = (
db.session.query(SqlaTable)
.filter_by(database_id=database_id, table_name=table_name)
@@ -2067,9 +2093,6 @@ class Superset(BaseSupersetView):
cols = []
for config in data.get("columns"):
column_name = config.get("name")
- SqlaTable = ConnectorRegistry.sources["table"]
- TableColumn = SqlaTable.column_class
- SqlMetric = SqlaTable.metric_class
col = TableColumn(
column_name=column_name,
filterable=True,
@@ -2085,20 +2108,24 @@ class Superset(BaseSupersetView):
return json_success(json.dumps({"table_id": table.id}))
@has_access
- @expose("/extra_table_metadata/<database_id>/<table_name>/<schema>/")
+ @expose("/extra_table_metadata/<int:database_id>/<table_name>/<schema>/")
@event_logger.log_this
- def extra_table_metadata(self, database_id, table_name, schema):
- schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
- table_name = utils.parse_js_uri_path_item(table_name)
+ def extra_table_metadata(
+ self, database_id: int, table_name: str, schema: str
+ ) -> FlaskResponse:
+ schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore
+ table_name = utils.parse_js_uri_path_item(table_name) # type: ignore
mydb = db.session.query(models.Database).filter_by(id=database_id).one()
payload = mydb.db_engine_spec.extra_table_metadata(mydb, table_name, schema)
return json_success(json.dumps(payload))
@has_access
- @expose("/select_star/<database_id>/<table_name>")
- @expose("/select_star/<database_id>/<table_name>/<schema>")
+ @expose("/select_star/<int:database_id>/<table_name>")
+ @expose("/select_star/<int:database_id>/<table_name>/<schema>")
@event_logger.log_this
- def select_star(self, database_id, table_name, schema=None):
+ def select_star(
+ self, database_id: int, table_name: str, schema: Optional[str] = None
+ ) -> FlaskResponse:
logging.warning(
f"{self.__class__.__name__}.select_star "
"This API endpoint is deprecated and will be removed in version 1.0.0"
@@ -2110,8 +2137,8 @@ class Superset(BaseSupersetView):
f"deprecated.{self.__class__.__name__}.select_star.database_not_found"
)
return json_error_response("Not found", 404)
- schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
- table_name = utils.parse_js_uri_path_item(table_name)
+ schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore
+ table_name = utils.parse_js_uri_path_item(table_name) # type: ignore
# Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource(
database, Table(table_name, schema), schema
@@ -2132,12 +2159,12 @@ class Superset(BaseSupersetView):
)
@has_access_api
- @expose("/estimate_query_cost/<database_id>/", methods=["POST"])
- @expose("/estimate_query_cost/<database_id>/<schema>/", methods=["POST"])
+ @expose("/estimate_query_cost/<int:database_id>/", methods=["POST"])
+ @expose("/estimate_query_cost/<int:database_id>/<schema>/", methods=["POST"])
@event_logger.log_this
def estimate_query_cost(
self, database_id: int, schema: Optional[str] = None
- ) -> Response:
+ ) -> FlaskResponse:
mydb = db.session.query(models.Database).get(database_id)
sql = json.loads(request.form.get("sql", '""'))
@@ -2157,7 +2184,7 @@ class Superset(BaseSupersetView):
logger.exception(ex)
return json_error_response(timeout_msg)
except Exception as ex:
- return json_error_response(str(ex))
+ return json_error_response(utils.error_msg_from_exception(ex))
spec = mydb.db_engine_spec
query_cost_formatters: Dict[str, Any] = get_feature_flags().get(
@@ -2171,16 +2198,16 @@ class Superset(BaseSupersetView):
return json_success(json.dumps(cost))
@expose("/theme/")
- def theme(self):
+ def theme(self) -> FlaskResponse:
return self.render_template("superset/theme.html")
@has_access_api
@expose("/results/<key>/")
@event_logger.log_this
- def results(self, key):
+ def results(self, key: str) -> FlaskResponse:
return self.results_exec(key)
- def results_exec(self, key: str):
+ def results_exec(self, key: str) -> FlaskResponse:
"""Serves a key off of the results backend
It is possible to pass the `rows` query argument to limit the number
@@ -2244,7 +2271,7 @@ class Superset(BaseSupersetView):
on_giveup=lambda details: db.session.rollback(),
max_tries=5,
)
- def stop_query(self):
+ def stop_query(self) -> FlaskResponse:
client_id = request.form.get("client_id")
query = db.session.query(Query).filter_by(client_id=client_id).one()
@@ -2265,12 +2292,12 @@ class Superset(BaseSupersetView):
@has_access_api
@expose("/validate_sql_json/", methods=["POST", "GET"])
@event_logger.log_this
- def validate_sql_json(self):
+ def validate_sql_json(self) -> FlaskResponse:
"""Validates that arbitrary sql is acceptable for the given database.
Returns a list of error/warning annotations as json.
"""
- sql = request.form.get("sql")
- database_id = request.form.get("database_id")
+ sql = request.form["sql"]
+ database_id = request.form["database_id"]
schema = request.form.get("schema") or None
template_params = json.loads(request.form.get("templateParams") or "{}")
@@ -2338,9 +2365,9 @@ class Superset(BaseSupersetView):
query: Query,
expand_data: bool,
log_params: Optional[Dict[str, Any]] = None,
- ) -> Response:
+ ) -> FlaskResponse:
"""
- Send SQL JSON query to celery workers
+ Send SQL JSON query to celery workers.
:param session: SQLAlchemy session object
:param rendered_query: the rendered query to perform by workers
@@ -2389,9 +2416,9 @@ class Superset(BaseSupersetView):
query: Query,
expand_data: bool,
log_params: Optional[Dict[str, Any]] = None,
- ) -> Response:
+ ) -> FlaskResponse:
"""
- Execute SQL query (sql json)
+ Execute SQL query (sql json).
:param rendered_query: The rendered query (included templates)
:param query: The query SQL (SQLAlchemy) object
@@ -2424,7 +2451,7 @@ class Superset(BaseSupersetView):
)
except Exception as ex:
logger.exception(f"Query {query.id}: {ex}")
- return json_error_response(f"{{e}}")
+ return json_error_response(utils.error_msg_from_exception(ex))
if data.get("status") == QueryStatus.FAILED:
return json_error_response(payload=data)
return json_success(payload)
@@ -2432,15 +2459,15 @@ class Superset(BaseSupersetView):
@has_access_api
@expose("/sql_json/", methods=["POST"])
@event_logger.log_this
- def sql_json(self):
+ def sql_json(self) -> FlaskResponse:
log_params = {
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
}
return self.sql_json_exec(request.json, log_params)
def sql_json_exec(
- self, query_params: dict, log_params: Optional[Dict[str, Any]] = None
- ):
+ self, query_params: Dict[str, Any], log_params: Optional[Dict[str, Any]] = None
+ ) -> FlaskResponse:
"""Runs arbitrary sql and returns data as json"""
# Collect Values
database_id: int = cast(int, query_params.get("database_id"))
@@ -2564,7 +2591,7 @@ class Superset(BaseSupersetView):
@has_access
@expose("/csv/<client_id>")
@event_logger.log_this
- def csv(self, client_id):
+ def csv(self, client_id: str) -> FlaskResponse:
"""Download the query results as csv."""
logger.info("Exporting CSV file [{}]".format(client_id))
query = db.session.query(Query).filter_by(client_id=client_id).one()
@@ -2587,7 +2614,7 @@ class Superset(BaseSupersetView):
blob, decode=not results_backend_use_msgpack
)
obj = _deserialize_results_payload(
- payload, query, results_backend_use_msgpack
+ payload, query, cast(bool, results_backend_use_msgpack)
)
columns = [c["name"] for c in obj["columns"]]
df = pd.DataFrame.from_records(obj["data"], columns=columns)
@@ -2622,8 +2649,8 @@ class Superset(BaseSupersetView):
@has_access
@expose("/fetch_datasource_metadata")
@event_logger.log_this
- def fetch_datasource_metadata(self):
- datasource_id, datasource_type = request.args.get("datasourceKey").split("__")
+ def fetch_datasource_metadata(self) -> FlaskResponse:
+ datasource_id, datasource_type = request.args["datasourceKey"].split("__")
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
@@ -2636,17 +2663,16 @@ class Superset(BaseSupersetView):
return json_success(json.dumps(datasource.data))
@has_access_api
- @expose("/queries/<last_updated_ms>")
- def queries(self, last_updated_ms):
+ @expose("/queries/<int:last_updated_ms>")
+ def queries(self, last_updated_ms: int) -> FlaskResponse:
"""
Get the updated queries.
:param last_updated_ms: unix time, milliseconds
"""
- last_updated_ms_int = int(float(last_updated_ms)) if last_updated_ms else 0
- return self.queries_exec(last_updated_ms_int)
+ return self.queries_exec(last_updated_ms)
- def queries_exec(self, last_updated_ms_int: int):
+ def queries_exec(self, last_updated_ms: int) -> FlaskResponse:
stats_logger.incr("queries")
if not g.user.get_id():
return json_error_response(
@@ -2654,7 +2680,7 @@ class Superset(BaseSupersetView):
)
# UTC date time, same that is stored in the DB.
- last_updated_dt = utils.EPOCH + timedelta(seconds=last_updated_ms_int / 1000)
+ last_updated_dt = utils.EPOCH + timedelta(seconds=last_updated_ms / 1000)
sql_queries = (
db.session.query(Query)
@@ -2669,7 +2695,7 @@ class Superset(BaseSupersetView):
@has_access
@expose("/search_queries")
@event_logger.log_this
- def search_queries(self) -> Response:
+ def search_queries(self) -> FlaskResponse:
"""
Search for previously run sqllab queries. Used for Sqllab Query Search
page /superset/sqllab#search.
@@ -2730,14 +2756,14 @@ class Superset(BaseSupersetView):
)
@app.errorhandler(500)
- def show_traceback(self):
+ def show_traceback(self) -> FlaskResponse:
return (
render_template("superset/traceback.html", error_msg=get_error_msg()),
500,
)
@expose("/welcome")
- def welcome(self):
+ def welcome(self) -> FlaskResponse:
"""Personalized welcome page"""
if not g.user or not g.user.get_id():
return redirect(appbuilder.get_url_for_login)
@@ -2765,11 +2791,8 @@ class Superset(BaseSupersetView):
@has_access
@expose("/profile/<username>/")
- def profile(self, username):
+ def profile(self, username: str) -> FlaskResponse:
"""User profile page"""
- if not username and g.user:
- username = g.user.username
-
user = (
db.session.query(ab_models.User).filter_by(username=username).one_or_none()
)
@@ -2839,7 +2862,7 @@ class Superset(BaseSupersetView):
@has_access
@expose("/sqllab", methods=["GET", "POST"])
- def sqllab(self):
+ def sqllab(self) -> FlaskResponse:
"""SQL Editor"""
payload = {
"defaultDbId": config["SQLLAB_DEFAULT_DBID"],
@@ -2864,7 +2887,7 @@ class Superset(BaseSupersetView):
@api
@has_access_api
@expose("/schemas_access_for_csv_upload")
- def schemas_access_for_csv_upload(self):
+ def schemas_access_for_csv_upload(self) -> FlaskResponse:
"""
This method exposes an API endpoint to
get the schema access control settings for csv upload in this database
@@ -2872,7 +2895,7 @@ class Superset(BaseSupersetView):
if not request.args.get("db_id"):
return json_error_response("No database is allowed for your csv upload")
- db_id = int(request.args.get("db_id"))
+ db_id = int(request.args["db_id"])
database = db.session.query(models.Database).filter_by(id=db_id).one()
try:
schemas_allowed = database.get_schema_access_for_csv_upload()
@@ -2919,11 +2942,11 @@ class CssTemplateAsyncModelView(CssTemplateModelView):
@app.after_request
-def apply_http_headers(response: Response):
+def apply_http_headers(response: Response) -> Response:
"""Applies the configuration's http headers to all responses"""
# HTTP_HEADERS is deprecated, this provides backwards compatibility
- response.headers.extend(
+ response.headers.extend( # type: ignore
{**config["OVERRIDE_HTTP_HEADERS"], **config["HTTP_HEADERS"]}
)
diff --git a/superset/views/datasource.py b/superset/views/datasource.py
index b641ee1..2ce1102 100644
--- a/superset/views/datasource.py
+++ b/superset/views/datasource.py
@@ -17,7 +17,7 @@
import json
from collections import Counter
-from flask import request, Response
+from flask import request
from flask_appbuilder import expose
from flask_appbuilder.security.decorators import has_access_api
from sqlalchemy.orm.exc import NoResultFound
@@ -25,6 +25,7 @@ from sqlalchemy.orm.exc import NoResultFound
from superset import db
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models.core import Database
+from superset.typing import FlaskResponse
from .base import api, BaseSupersetView, handle_api_exception, json_error_response
@@ -36,7 +37,7 @@ class Datasource(BaseSupersetView):
@has_access_api
@api
@handle_api_exception
- def save(self) -> Response:
+ def save(self) -> FlaskResponse:
data = request.form.get("data")
if not isinstance(data, str):
return json_error_response("Request missing data field.", status=500)
@@ -78,7 +79,7 @@ class Datasource(BaseSupersetView):
@has_access_api
@api
@handle_api_exception
- def get(self, datasource_type: str, datasource_id: int) -> Response:
+ def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
try:
orm_datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
@@ -95,7 +96,9 @@ class Datasource(BaseSupersetView):
@has_access_api
@api
@handle_api_exception
- def external_metadata(self, datasource_type: str, datasource_id: int) -> Response:
+ def external_metadata(
+ self, datasource_type: str, datasource_id: int
+ ) -> FlaskResponse:
"""Gets column info from the source system"""
if datasource_type == "druid":
datasource = ConnectorRegistry.get_datasource(
diff --git a/superset/views/filters.py b/superset/views/filters.py
index 3e4d85a..3594d21 100644
--- a/superset/views/filters.py
+++ b/superset/views/filters.py
@@ -14,9 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Any, cast, Optional
+
from flask_appbuilder.models.filters import BaseFilter
from flask_babel import lazy_gettext
from sqlalchemy import or_
+from sqlalchemy.orm import Query
from superset import security_manager
@@ -36,9 +39,9 @@ class FilterRelatedOwners(BaseFilter):
name = lazy_gettext("Owner")
arg_name = "owners"
- def apply(self, query, value):
+ def apply(self, query: Query, value: Optional[Any]) -> Query:
user_model = security_manager.user_model
- like_value = "%" + value + "%"
+ like_value = "%" + cast(str, value) + "%"
return query.filter(
or_(
# could be made to handle spaces between names more gracefully
diff --git a/superset/views/log/__init__.py b/superset/views/log/__init__.py
index b39d602..103632b 100644
--- a/superset/views/log/__init__.py
+++ b/superset/views/log/__init__.py
@@ -23,8 +23,8 @@ class LogMixin: # pylint: disable=too-few-public-methods
add_title = _("Add Log")
edit_title = _("Edit Log")
- list_columns = ("user", "action", "dttm")
- edit_columns = ("user", "action", "dttm", "json")
+ list_columns = ["user", "action", "dttm"]
+ edit_columns = ["user", "action", "dttm", "json"]
base_order = ("dttm", "desc")
label_columns = {
"user": _("User"),
diff --git a/superset/views/log/api.py b/superset/views/log/api.py
index d579eb0..f132c34 100644
--- a/superset/views/log/api.py
+++ b/superset/views/log/api.py
@@ -28,5 +28,5 @@ class LogRestApi(LogMixin, BaseSupersetModelRestApi):
class_permission_name = "LogModelView"
resource_name = "log"
allow_browser_login = True
- list_columns = ("user.username", "action", "dttm")
+ list_columns = ["user.username", "action", "dttm"]
show_columns = list_columns
diff --git a/superset/views/schedules.py b/superset/views/schedules.py
index 6f35a7a..68ae6ff 100644
--- a/superset/views/schedules.py
+++ b/superset/views/schedules.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import enum
-from typing import Optional, Type
+from typing import Type
import simplejson as json
from croniter import croniter
@@ -24,7 +24,7 @@ from flask_appbuilder import expose
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import has_access
from flask_babel import lazy_gettext as _
-from wtforms import BooleanField, StringField
+from wtforms import BooleanField, Form, StringField
from superset import db, security_manager
from superset.constants import RouteMethod
@@ -37,6 +37,7 @@ from superset.models.schedules import (
)
from superset.models.slice import Slice
from superset.tasks.schedules import schedule_email_report
+from superset.typing import FlaskResponse
from superset.utils.core import get_email_address_list, json_iso_dttm_ser
from superset.views.core import json_success
@@ -48,8 +49,14 @@ class EmailScheduleView(
): # pylint: disable=too-many-ancestors
include_route_methods = RouteMethod.CRUD_SET
_extra_data = {"test_email": False, "test_email_recipients": None}
- schedule_type: Optional[str] = None
- schedule_type_model: Optional[Type] = None
+
+ @property
+ def schedule_type(self) -> str:
+ raise NotImplementedError()
+
+ @property
+ def schedule_type_model(self) -> Type:
+ raise NotImplementedError()
page_size = 20
@@ -87,7 +94,7 @@ class EmailScheduleView(
edit_form_extra_fields = add_form_extra_fields
- def process_form(self, form, is_created):
+ def process_form(self, form: Form, is_created: bool) -> None:
if form.test_email_recipients.data:
test_email_recipients = form.test_email_recipients.data.strip()
else:
@@ -95,7 +102,7 @@ class EmailScheduleView(
self._extra_data["test_email"] = form.test_email.data
self._extra_data["test_email_recipients"] = test_email_recipients
- def pre_add(self, item):
+ def pre_add(self, item: "EmailScheduleView") -> None:
try:
recipients = get_email_address_list(item.recipients)
item.recipients = ", ".join(recipients)
@@ -106,10 +113,10 @@ class EmailScheduleView(
if not croniter.is_valid(item.crontab):
raise SupersetException("Invalid crontab format")
- def pre_update(self, item):
+ def pre_update(self, item: "EmailScheduleView") -> None:
self.pre_add(item)
- def post_add(self, item):
+ def post_add(self, item: "EmailScheduleView") -> None:
# Schedule a test mail if the user requested for it.
if self._extra_data["test_email"]:
recipients = self._extra_data["test_email_recipients"] or item.recipients
@@ -122,12 +129,12 @@ class EmailScheduleView(
if item.active:
flash("Schedule changes will get applied in one hour", "warning")
- def post_update(self, item):
+ def post_update(self, item: "EmailScheduleView") -> None:
self.post_add(item)
@has_access
@expose("/fetch/<int:item_id>/", methods=["GET"])
- def fetch_schedules(self, item_id):
+ def fetch_schedules(self, item_id: int) -> FlaskResponse:
query = db.session.query(self.datamodel.obj)
query = query.join(self.schedule_type_model).filter(
@@ -147,7 +154,9 @@ class EmailScheduleView(
info[col] = info[col].username
info["user"] = schedule.user.username
- info[self.schedule_type] = getattr(schedule, self.schedule_type).id
+ info[self.schedule_type] = getattr( # type: ignore
+ schedule, self.schedule_type
+ ).id
schedules.append(info)
return json_success(json.dumps(schedules, default=json_iso_dttm_ser))
@@ -208,7 +217,7 @@ class DashboardEmailScheduleView(
"delivery_type": _("Delivery Type"),
}
- def pre_add(self, item):
+ def pre_add(self, item: "DashboardEmailScheduleView") -> None:
if item.dashboard is None:
raise SupersetException("Dashboard is mandatory")
super(DashboardEmailScheduleView, self).pre_add(item)
@@ -269,7 +278,7 @@ class SliceEmailScheduleView(EmailScheduleView): # pylint: disable=too-many-anc
"email_format": _("Email Format"),
}
- def pre_add(self, item):
+ def pre_add(self, item: "SliceEmailScheduleView") -> None:
if item.slice is None:
raise SupersetException("Slice is mandatory")
super(SliceEmailScheduleView, self).pre_add(item)
diff --git a/superset/views/sql_lab.py b/superset/views/sql_lab.py
index 3646965..3476bb3 100644
--- a/superset/views/sql_lab.py
+++ b/superset/views/sql_lab.py
@@ -27,6 +27,7 @@ from flask_sqlalchemy import BaseQuery
from superset import db, get_feature_flags, security_manager
from superset.constants import RouteMethod
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
+from superset.typing import FlaskResponse
from superset.utils import core as utils
from .base import (
@@ -120,15 +121,15 @@ class SavedQueryView(
show_template = "superset/models/savedquery/show.html"
- def pre_add(self, item):
+ def pre_add(self, item: "SavedQueryView") -> None:
item.user = g.user
- def pre_update(self, item):
+ def pre_update(self, item: "SavedQueryView") -> None:
self.pre_add(item)
@has_access
@expose("show/<pk>")
- def show(self, pk):
+ def show(self, pk: int) -> FlaskResponse:
pk = self._deserialize_pk_if_composite(pk)
widgets = self._show(pk)
query = self.datamodel.get(pk).to_json()
@@ -168,18 +169,18 @@ class SavedQueryViewApi(SavedQueryView): # pylint: disable=too-many-ancestors
@has_access_api
@expose("show/<pk>")
- def show(self, pk):
+ def show(self, pk: int) -> FlaskResponse:
return super().show(pk)
-def _get_owner_id(tab_state_id):
+def _get_owner_id(tab_state_id: int) -> int:
return db.session.query(TabState.user_id).filter_by(id=tab_state_id).scalar()
class TabStateView(BaseSupersetView):
@has_access_api
@expose("/", methods=["POST"])
- def post(self): # pylint: disable=no-self-use
+ def post(self) -> FlaskResponse: # pylint: disable=no-self-use
query_editor = json.loads(request.form["queryEditor"])
tab_state = TabState(
user_id=g.user.get_id(),
@@ -201,7 +202,7 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("/<int:tab_state_id>", methods=["DELETE"])
- def delete(self, tab_state_id): # pylint: disable=no-self-use
+ def delete(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
return Response(status=403)
@@ -216,7 +217,7 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("/<int:tab_state_id>", methods=["GET"])
- def get(self, tab_state_id): # pylint: disable=no-self-use
+ def get(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
return Response(status=403)
@@ -229,7 +230,9 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("<int:tab_state_id>/activate", methods=["POST"])
- def activate(self, tab_state_id): # pylint: disable=no-self-use
+ def activate( # pylint: disable=no-self-use
+ self, tab_state_id: int
+ ) -> FlaskResponse:
owner_id = _get_owner_id(tab_state_id)
if owner_id is None:
return Response(status=404)
@@ -246,7 +249,7 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("<int:tab_state_id>", methods=["PUT"])
- def put(self, tab_state_id): # pylint: disable=no-self-use
+ def put(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
return Response(status=403)
@@ -257,7 +260,9 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("<int:tab_state_id>/migrate_query", methods=["POST"])
- def migrate_query(self, tab_state_id): # pylint: disable=no-self-use
+ def migrate_query( # pylint: disable=no-self-use
+ self, tab_state_id: int
+ ) -> FlaskResponse:
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
return Response(status=403)
@@ -270,7 +275,9 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("<int:tab_state_id>/query/<client_id>", methods=["DELETE"])
- def delete_query(self, tab_state_id, client_id): # pylint: disable=no-self-use
+ def delete_query( # pylint: disable=no-self-use
+ self, tab_state_id: str, client_id: str
+ ) -> FlaskResponse:
db.session.query(Query).filter_by(
client_id=client_id, user_id=g.user.get_id(), sql_editor_id=tab_state_id
).delete(synchronize_session=False)
@@ -281,7 +288,7 @@ class TabStateView(BaseSupersetView):
class TableSchemaView(BaseSupersetView):
@has_access_api
@expose("/", methods=["POST"])
- def post(self): # pylint: disable=no-self-use
+ def post(self) -> FlaskResponse: # pylint: disable=no-self-use
table = json.loads(request.form["table"])
# delete any existing table schema
@@ -306,7 +313,9 @@ class TableSchemaView(BaseSupersetView):
@has_access_api
@expose("/<int:table_schema_id>", methods=["DELETE"])
- def delete(self, table_schema_id): # pylint: disable=no-self-use
+ def delete( # pylint: disable=no-self-use
+ self, table_schema_id: int
+ ) -> FlaskResponse:
db.session.query(TableSchema).filter(TableSchema.id == table_schema_id).delete(
synchronize_session=False
)
@@ -315,7 +324,9 @@ class TableSchemaView(BaseSupersetView):
@has_access_api
@expose("/<int:table_schema_id>/expanded", methods=["POST"])
- def expanded(self, table_schema_id): # pylint: disable=no-self-use
+ def expanded( # pylint: disable=no-self-use
+ self, table_schema_id: int
+ ) -> FlaskResponse:
payload = json.loads(request.form["expanded"])
(
db.session.query(TableSchema)
@@ -332,6 +343,6 @@ class SqlLab(BaseSupersetView):
@expose("/my_queries/")
@has_access
- def my_queries(self): # pylint: disable=no-self-use
+ def my_queries(self) -> FlaskResponse: # pylint: disable=no-self-use
"""Assigns a list of found users to the given role."""
return redirect("/savedqueryview/list/?_flt_0_user={}".format(g.user.id))
diff --git a/superset/views/tags.py b/superset/views/tags.py
index e12df2a..2bcc0c7 100644
--- a/superset/views/tags.py
+++ b/superset/views/tags.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
+from typing import Any, Dict, List
+
import simplejson as json
from flask import request, Response
from flask_appbuilder import expose
@@ -29,11 +31,12 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.models.tags import ObjectTypes, Tag, TaggedObject, TagTypes
+from superset.typing import FlaskResponse
from .base import BaseSupersetView, json_success
-def process_template(content):
+def process_template(content: str) -> str:
env = SandboxedEnvironment()
template = env.from_string(content)
context = {
@@ -46,7 +49,7 @@ def process_template(content):
class TagView(BaseSupersetView):
@has_access_api
@expose("/tags/suggestions/", methods=["GET"])
- def suggestions(self): # pylint: disable=no-self-use
+ def suggestions(self) -> FlaskResponse: # pylint: disable=no-self-use
query = (
db.session.query(TaggedObject)
.join(Tag)
@@ -60,7 +63,9 @@ class TagView(BaseSupersetView):
@has_access_api
@expose("/tags/<object_type:object_type>/<int:object_id>/", methods=["GET"])
- def get(self, object_type, object_id): # pylint: disable=no-self-use
+ def get( # pylint: disable=no-self-use
+ self, object_type: ObjectTypes, object_id: int
+ ) -> FlaskResponse:
"""List all tags a given object has."""
if object_id == 0:
return json_success(json.dumps([]))
@@ -76,7 +81,9 @@ class TagView(BaseSupersetView):
@has_access_api
@expose("/tags/<object_type:object_type>/<int:object_id>/", methods=["POST"])
- def post(self, object_type, object_id): # pylint: disable=no-self-use
+ def post( # pylint: disable=no-self-use
+ self, object_type: ObjectTypes, object_id: int
+ ) -> FlaskResponse:
"""Add new tags to an object."""
if object_id == 0:
return Response(status=404)
@@ -104,7 +111,9 @@ class TagView(BaseSupersetView):
@has_access_api
@expose("/tags/<object_type:object_type>/<int:object_id>/", methods=["DELETE"])
- def delete(self, object_type, object_id): # pylint: disable=no-self-use
+ def delete( # pylint: disable=no-self-use
+ self, object_type: ObjectTypes, object_id: int
+ ) -> FlaskResponse:
"""Remove tags from an object."""
tag_names = request.get_json(force=True)
if not tag_names:
@@ -123,7 +132,7 @@ class TagView(BaseSupersetView):
@has_access_api
@expose("/tagged_objects/", methods=["GET", "POST"])
- def tagged_objects(self): # pylint: disable=no-self-use
+ def tagged_objects(self) -> FlaskResponse: # pylint: disable=no-self-use
tags = [
process_template(tag)
for tag in request.args.get("tags", "").split(",")
@@ -135,7 +144,7 @@ class TagView(BaseSupersetView):
# filter types
types = [type_ for type_ in request.args.get("types", "").split(",") if type_]
- results = []
+ results: List[Dict[str, Any]] = []
# dashboards
if not types or "dashboard" in types:
diff --git a/superset/views/utils.py b/superset/views/utils.py
index 5ed7e48..4edd2e7 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -16,11 +16,12 @@
# under the License.
from collections import defaultdict
from datetime import date
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple
from urllib import parse
import simplejson as json
from flask import g, request
+from flask_appbuilder.security.sqla.models import User
import superset.models.core as models
from superset import app, db, is_feature_enabled
@@ -29,7 +30,9 @@ from superset.exceptions import SupersetException
from superset.legacy import update_time_range
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
+from superset.typing import FormData
from superset.utils.core import QueryStatus, TimeRangeEndpoint
+from superset.viz import BaseViz
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
from superset import viz_sip38 as viz # type: ignore
@@ -42,7 +45,7 @@ if not app.config["ENABLE_JAVASCRIPT_CONTROLS"]:
FORM_DATA_KEY_BLACKLIST = ["js_tooltip", "js_onclick_href", "js_data_mutator"]
-def bootstrap_user_data(user, include_perms=False):
+def bootstrap_user_data(user: User, include_perms: bool = False) -> Dict[str, Any]:
if user.is_anonymous:
return {}
payload = {
@@ -63,7 +66,9 @@ def bootstrap_user_data(user, include_perms=False):
return payload
-def get_permissions(user):
+def get_permissions(
+ user: User,
+) -> Tuple[Dict[str, List[List[str]]], DefaultDict[str, Set[str]]]:
if not user.roles:
raise AttributeError("User object does not have roles")
@@ -86,11 +91,8 @@ def get_permissions(user):
def get_viz(
- form_data: Dict[str, Any],
- datasource_type: str,
- datasource_id: int,
- force: bool = False,
-):
+ form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False,
+) -> BaseViz:
viz_type = form_data.get("viz_type", "table")
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
@@ -158,9 +160,7 @@ def get_form_data(
def get_datasource_info(
- datasource_id: Optional[int],
- datasource_type: Optional[str],
- form_data: Dict[str, Any],
+ datasource_id: Optional[int], datasource_type: Optional[str], form_data: FormData,
) -> Tuple[int, Optional[str]]:
"""
Compatibility layer for handling of datasource info
@@ -222,9 +222,7 @@ def apply_display_max_row_limit(
def get_time_range_endpoints(
- form_data: Dict[str, Any],
- slc: Optional[Slice] = None,
- slice_id: Optional[int] = None,
+ form_data: FormData, slc: Optional[Slice] = None, slice_id: Optional[int] = None,
) -> Optional[Tuple[TimeRangeEndpoint, TimeRangeEndpoint]]:
"""
Get the slice aware time range endpoints from the form-data falling back to the SQL
diff --git a/superset/viz.py b/superset/viz.py
index 3fdd81e..d53dcf2 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -525,7 +525,7 @@ class BaseViz:
has_error = (
payload.get("status") == utils.QueryStatus.FAILED
or payload.get("error") is not None
- or len(payload.get("errors") or []) > 0
+ or bool(payload.get("errors"))
)
return self.json_dumps(payload), has_error
diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py
index 78d15d9..32df001 100644
--- a/superset/viz_sip38.py
+++ b/superset/viz_sip38.py
@@ -56,7 +56,7 @@ from superset.exceptions import (
SpatialException,
)
from superset.models.helpers import QueryResult
-from superset.typing import VizData
+from superset.typing import QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils
from superset.utils.core import (
DTTM_ALIAS,
@@ -251,7 +251,7 @@ class BaseViz:
df = df[min_periods:]
return df
- def get_samples(self):
+ def get_samples(self) -> List[Dict[str, Any]]:
query_obj = self.query_obj()
query_obj.update(
{
@@ -452,7 +452,7 @@ class BaseViz:
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
- def get_payload(self, query_obj=None):
+ def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload:
"""Returns a payload of metadata and data"""
self.run_extra_queries()
payload = self.get_df_payload(query_obj)
@@ -464,7 +464,9 @@ class BaseViz:
del payload["df"]
return payload
- def get_df_payload(self, query_obj=None, **kwargs):
+ def get_df_payload(
+ self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any
+ ) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
if not query_obj:
query_obj = self.query_obj()
@@ -559,11 +561,11 @@ class BaseViz:
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
- def payload_json_and_has_error(self, payload):
+ def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
has_error = (
payload.get("status") == utils.QueryStatus.FAILED
or payload.get("error") is not None
- or len(payload.get("errors")) > 0
+ or len(payload.get("errors", [])) > 0
)
return self.json_dumps(payload), has_error
@@ -578,7 +580,7 @@ class BaseViz:
}
return content
- def get_csv(self):
+ def get_csv(self) -> Optional[str]:
df = self.get_df()
include_index = not isinstance(df.index, pd.RangeIndex)
return df.to_csv(index=include_index, **config["CSV_EXPORT"])