You are viewing a plain text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@superset.apache.org by jo...@apache.org on 2020/06/07 15:54:03 UTC
[incubator-superset] branch master updated: style(mypy):
Spit-and-polish pass (#10001)
This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 91517a5 style(mypy): Spit-and-polish pass (#10001)
91517a5 is described below
commit 91517a56a3bcacd4c8f2dca233d8df248bfca10e
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Sun Jun 7 08:53:46 2020 -0700
style(mypy): Spit-and-polish pass (#10001)
Co-authored-by: John Bodley <jo...@airbnb.com>
---
setup.cfg | 4 +-
superset/app.py | 4 +-
superset/charts/commands/create.py | 4 +-
superset/charts/commands/update.py | 4 +-
superset/common/query_object.py | 13 +++---
superset/config.py | 7 +--
superset/connectors/base/models.py | 13 +++---
superset/connectors/connector_registry.py | 5 ++-
superset/connectors/druid/models.py | 51 ++++++++++------------
superset/connectors/sqla/models.py | 15 ++++---
superset/dao/base.py | 8 ++--
superset/dashboards/commands/create.py | 4 +-
superset/dashboards/commands/update.py | 4 +-
superset/datasets/commands/create.py | 4 +-
superset/datasets/commands/update.py | 10 ++---
superset/datasets/dao.py | 12 ++---
superset/db_engine_specs/base.py | 12 ++---
superset/db_engine_specs/bigquery.py | 2 +-
superset/db_engine_specs/exasol.py | 2 +-
superset/db_engine_specs/hive.py | 6 +--
superset/db_engine_specs/mssql.py | 2 +-
superset/db_engine_specs/postgres.py | 2 +-
superset/db_engine_specs/presto.py | 19 ++++----
superset/extensions.py | 6 ++-
superset/models/core.py | 9 ++--
superset/models/dashboard.py | 8 ++--
superset/models/helpers.py | 2 +-
superset/models/slice.py | 2 +-
superset/models/sql_types/presto_sql_types.py | 12 ++---
superset/queries/filters.py | 4 +-
superset/result_set.py | 6 +--
superset/security/manager.py | 15 +++----
superset/sql_lab.py | 7 +--
superset/tasks/celery_app.py | 2 +-
superset/tasks/schedules.py | 10 +++--
superset/utils/cache.py | 7 +--
superset/utils/core.py | 18 +++++---
.../utils/dashboard_filter_scopes_converter.py | 11 ++---
superset/utils/decorators.py | 4 +-
superset/utils/import_datasource.py | 8 ++--
superset/utils/log.py | 4 +-
superset/utils/logging_configurator.py | 2 +-
superset/utils/pandas_postprocessing.py | 6 +--
superset/utils/screenshots.py | 14 +++---
superset/views/base.py | 14 +++---
superset/views/base_api.py | 2 +-
superset/views/base_schemas.py | 4 +-
superset/views/core.py | 14 +++---
superset/views/database/api.py | 10 ++---
superset/views/database/decorators.py | 2 +-
superset/views/schedules.py | 8 ++--
superset/views/sql_lab.py | 4 +-
superset/views/utils.py | 12 ++---
superset/viz.py | 4 +-
tests/base_tests.py | 10 +++--
tests/superset_test_config_thumbnails.py | 2 +-
56 files changed, 243 insertions(+), 207 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index 81c7ed2..93e33af 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -50,10 +50,12 @@ multi_line_output = 3
order_by_type = false
[mypy]
+disallow_any_generics = true
ignore_missing_imports = true
no_implicit_optional = true
+warn_unused_ignores = true
-[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset [...]
+[mypy-superset.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
diff --git a/superset/app.py b/superset/app.py
index 18165ed..b36df75 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -80,7 +80,7 @@ class SupersetAppInitializer:
self.flask_app = app
self.config = app.config
- self.manifest: dict = {}
+ self.manifest: Dict[Any, Any] = {}
def pre_init(self) -> None:
"""
@@ -542,7 +542,7 @@ class SupersetAppInitializer:
self.app = app
def __call__(
- self, environ: Dict[str, Any], start_response: Callable
+ self, environ: Dict[str, Any], start_response: Callable[..., Any]
) -> Any:
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end.
diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py
index 8e7dcb7..1425396 100644
--- a/superset/charts/commands/create.py
+++ b/superset/charts/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
class CreateChartCommand(BaseCommand):
- def __init__(self, user: User, data: Dict):
+ def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()
diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py
index 21c236c..70055bf 100644
--- a/superset/charts/commands/update.py
+++ b/superset/charts/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
class UpdateChartCommand(BaseCommand):
- def __init__(self, user: User, model_id: int, data: Dict):
+ def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id
self._properties = data.copy()
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 3c5b778..188d0b3 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -26,6 +26,7 @@ from pandas import DataFrame
from superset import app, is_feature_enabled
from superset.exceptions import QueryObjectValidationError
+from superset.typing import Metric
from superset.utils import core as utils, pandas_postprocessing
from superset.views.utils import get_time_range_endpoints
@@ -67,11 +68,11 @@ class QueryObject:
row_limit: int
filter: List[Dict[str, Any]]
timeseries_limit: int
- timeseries_limit_metric: Optional[Dict]
+ timeseries_limit_metric: Optional[Metric]
order_desc: bool
- extras: Dict
+ extras: Dict[str, Any]
columns: List[str]
- orderby: List[List]
+ orderby: List[List[str]]
post_processing: List[Dict[str, Any]]
def __init__(
@@ -85,11 +86,11 @@ class QueryObject:
is_timeseries: bool = False,
timeseries_limit: int = 0,
row_limit: int = app.config["ROW_LIMIT"],
- timeseries_limit_metric: Optional[Dict] = None,
+ timeseries_limit_metric: Optional[Metric] = None,
order_desc: bool = True,
- extras: Optional[Dict] = None,
+ extras: Optional[Dict[str, Any]] = None,
columns: Optional[List[str]] = None,
- orderby: Optional[List[List]] = None,
+ orderby: Optional[List[List[str]]] = None,
post_processing: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any,
):
diff --git a/superset/config.py b/superset/config.py
index 35dbbf8..6da24b4 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -33,6 +33,7 @@ from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
from cachelib.base import BaseCache
from celery.schedules import crontab
from dateutil import tz
+from flask import Blueprint
from flask_appbuilder.security.manager import AUTH_DB
from superset.jinja_context import ( # pylint: disable=unused-import
@@ -421,7 +422,7 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
]
)
ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
-ADDITIONAL_MIDDLEWARE: List[Callable] = []
+ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
# 1) https://docs.python-guide.org/writing/logging/
# 2) https://docs.python.org/2/library/logging.config.html
@@ -624,7 +625,7 @@ ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
# SQL Lab. The existing context gets updated with this dictionary,
# meaning values for existing keys get overwritten by the content of this
# dictionary.
-JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {}
+JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
# A dictionary of macro template processors that gets merged into global
# template processors. The existing template processors get updated with this
@@ -684,7 +685,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
# Integrate external Blueprints to the app by passing them to your
# configuration. These blueprints will get integrated in the app
-BLUEPRINTS: List[Callable] = []
+BLUEPRINTS: List[Blueprint] = []
# Provide a callable that receives a tracking_url and returns another
# URL. This is used to translate internal Hadoop job tracker URL
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 0533aa1..fb2d5eb 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import json
-from typing import Any, Dict, Hashable, List, Optional, Type
+from typing import Any, Dict, Hashable, List, Optional, Type, Union
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
@@ -64,12 +64,12 @@ class BaseDatasource(
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
@property
- def column_class(self) -> Type:
+ def column_class(self) -> Type["BaseColumn"]:
# link to derivative of BaseColumn
raise NotImplementedError()
@property
- def metric_class(self) -> Type:
+ def metric_class(self) -> Type["BaseMetric"]:
# link to derivative of BaseMetric
raise NotImplementedError()
@@ -368,7 +368,7 @@ class BaseDatasource(
"""
raise NotImplementedError()
- def values_for_column(self, column_name: str, limit: int = 10000) -> List:
+ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
@@ -389,7 +389,10 @@ class BaseDatasource(
@staticmethod
def get_fk_many_from_list(
- object_list: List[Any], fkmany: List[Column], fkmany_class: Type, key_attr: str,
+ object_list: List[Any],
+ fkmany: List[Column],
+ fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
+ key_attr: str,
) -> List[Column]: # pylint: disable=too-many-locals
"""Update ORM one-to-many list from object list
diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py
index 4097066..3b11973 100644
--- a/superset/connectors/connector_registry.py
+++ b/superset/connectors/connector_registry.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from collections import OrderedDict
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
from sqlalchemy import or_
@@ -22,6 +21,8 @@ from sqlalchemy.orm import Session, subqueryload
if TYPE_CHECKING:
# pylint: disable=unused-import
+ from collections import OrderedDict
+
from superset.models.core import Database
from superset.connectors.base.models import BaseDatasource
@@ -32,7 +33,7 @@ class ConnectorRegistry:
sources: Dict[str, Type["BaseDatasource"]] = {}
@classmethod
- def register_sources(cls, datasource_config: OrderedDict) -> None:
+ def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None:
for module_name, class_names in datasource_config.items():
class_names = [str(s) for s in class_names]
module_obj = __import__(module_name, fromlist=class_names)
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 50f1637..4de56c9 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -24,18 +24,7 @@ from copy import deepcopy
from datetime import datetime, timedelta
from distutils.version import LooseVersion
from multiprocessing.pool import ThreadPool
-from typing import (
- Any,
- Callable,
- cast,
- Dict,
- Iterable,
- List,
- Optional,
- Set,
- Tuple,
- Union,
-)
+from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union
import pandas as pd
import sqlalchemy as sa
@@ -173,7 +162,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
return self.__repr__()
@property
- def data(self) -> Dict:
+ def data(self) -> Dict[str, Any]:
return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
@staticmethod
@@ -354,7 +343,7 @@ class DruidColumn(Model, BaseColumn):
return self.dimension_spec_json
@property
- def dimension_spec(self) -> Optional[Dict]:
+ def dimension_spec(self) -> Optional[Dict[str, Any]]:
if self.dimension_spec_json:
return json.loads(self.dimension_spec_json)
return None
@@ -438,7 +427,7 @@ class DruidMetric(Model, BaseMetric):
return self.json
@property
- def json_obj(self) -> Dict:
+ def json_obj(self) -> Dict[str, Any]:
try:
obj = json.loads(self.json)
except Exception:
@@ -614,7 +603,7 @@ class DruidDatasource(Model, BaseDatasource):
name = escape(self.datasource_name)
return Markup(f'<a href="{url}">{name}</a>')
- def get_metric_obj(self, metric_name: str) -> Dict:
+ def get_metric_obj(self, metric_name: str) -> Dict[str, Any]:
return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
@classmethod
@@ -705,7 +694,11 @@ class DruidDatasource(Model, BaseDatasource):
@classmethod
def sync_to_db_from_config(
- cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True
+ cls,
+ druid_config: Dict[str, Any],
+ user: User,
+ cluster: DruidCluster,
+ refresh: bool = True,
) -> None:
"""Merges the ds config from druid_config into one stored in the db."""
session = db.session
@@ -901,7 +894,7 @@ class DruidDatasource(Model, BaseDatasource):
return postagg_metrics
@staticmethod
- def recursive_get_fields(_conf: Dict) -> List[str]:
+ def recursive_get_fields(_conf: Dict[str, Any]) -> List[str]:
_type = _conf.get("type")
_field = _conf.get("field")
_fields = _conf.get("fields")
@@ -957,8 +950,8 @@ class DruidDatasource(Model, BaseDatasource):
@staticmethod
def metrics_and_post_aggs(
- metrics: List[Union[Dict, str]], metrics_dict: Dict[str, DruidMetric],
- ) -> Tuple[OrderedDict, OrderedDict]:
+ metrics: List[Metric], metrics_dict: Dict[str, DruidMetric],
+ ) -> Tuple["OrderedDict[str, Any]", "OrderedDict[str, Any]"]:
# Separate metrics into those that are aggregations
# and those that are post aggregations
saved_agg_names = set()
@@ -987,7 +980,7 @@ class DruidDatasource(Model, BaseDatasource):
)
return aggs, post_aggs
- def values_for_column(self, column_name: str, limit: int = 10000) -> List:
+ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
"""Retrieve some values for the given column"""
logger.info(
"Getting values for columns [{}] limited to [{}]".format(column_name, limit)
@@ -1079,8 +1072,10 @@ class DruidDatasource(Model, BaseDatasource):
@staticmethod
def get_aggregations(
- metrics_dict: Dict, saved_metrics: Set[str], adhoc_metrics: List[Dict] = []
- ) -> OrderedDict:
+ metrics_dict: Dict[str, Any],
+ saved_metrics: Set[str],
+ adhoc_metrics: Optional[List[Dict[str, Any]]] = None,
+ ) -> "OrderedDict[str, Any]":
"""
Returns a dictionary of aggregation metric names to aggregation json objects
@@ -1089,7 +1084,9 @@ class DruidDatasource(Model, BaseDatasource):
:param adhoc_metrics: list of adhoc metric names
:raise SupersetException: if one or more metric names are not aggregations
"""
- aggregations: OrderedDict = OrderedDict()
+ if not adhoc_metrics:
+ adhoc_metrics = []
+ aggregations = OrderedDict()
invalid_metric_names = []
for metric_name in saved_metrics:
if metric_name in metrics_dict:
@@ -1115,7 +1112,7 @@ class DruidDatasource(Model, BaseDatasource):
def get_dimensions(
self, columns: List[str], columns_dict: Dict[str, DruidColumn]
- ) -> List[Union[str, Dict]]:
+ ) -> List[Union[str, Dict[str, Any]]]:
dimensions = []
columns = [col for col in columns if col in columns_dict]
for column_name in columns:
@@ -1433,7 +1430,7 @@ class DruidDatasource(Model, BaseDatasource):
df[columns] = df[columns].fillna(NULL_STRING).astype("unicode")
return df
- def query(self, query_obj: Dict) -> QueryResult:
+ def query(self, query_obj: QueryObjectDict) -> QueryResult:
qry_start_dttm = datetime.now()
client = self.cluster.get_pydruid_client()
query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2)
@@ -1583,7 +1580,7 @@ class DruidDatasource(Model, BaseDatasource):
dimension=col, value=eq, extraction_function=extraction_fn
)
elif is_list_target:
- eq = cast(list, eq)
+ eq = cast(List[Any], eq)
fields = []
# ignore the filter if it has no value
if not len(eq):
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 0e91bd2..4e93d5f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -597,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
)
@property
- def data(self) -> Dict:
+ def data(self) -> Dict[str, Any]:
d = super().data
if self.type == "table":
grains = self.database.grains() or []
@@ -684,7 +684,9 @@ class SqlaTable(Model, BaseDatasource):
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
return self.get_sqla_table()
- def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]:
+ def adhoc_metric_to_sqla(
+ self, metric: Dict[str, Any], cols: Dict[str, Any]
+ ) -> Optional[Column]:
"""
Turn an adhoc metric into a sqlalchemy column.
@@ -804,7 +806,7 @@ class SqlaTable(Model, BaseDatasource):
main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
select_exprs: List[Column] = []
- groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()
+ groupby_exprs_sans_timestamp = OrderedDict()
if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
# dedup columns while preserving order
@@ -874,7 +876,7 @@ class SqlaTable(Model, BaseDatasource):
qry = qry.group_by(*groupby_exprs_with_timestamp.values())
where_clause_and = []
- having_clause_and: List = []
+ having_clause_and = []
for flt in filter: # type: ignore
if not all([flt.get(s) for s in ["col", "op"]]):
@@ -1082,7 +1084,10 @@ class SqlaTable(Model, BaseDatasource):
return ob
def _get_top_groups(
- self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict
+ self,
+ df: pd.DataFrame,
+ dimensions: List[str],
+ groupby_exprs: "OrderedDict[str, Any]",
) -> ColumnElement:
groups = []
for unused, row in df.iterrows():
diff --git a/superset/dao/base.py b/superset/dao/base.py
index 020feed..59791ff 100644
--- a/superset/dao/base.py
+++ b/superset/dao/base.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
@@ -75,7 +75,7 @@ class BaseDAO:
return query.all()
@classmethod
- def create(cls, properties: Dict, commit: bool = True) -> Model:
+ def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
"""
Generic for creating models
:raises: DAOCreateFailedError
@@ -95,7 +95,9 @@ class BaseDAO:
return model
@classmethod
- def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model:
+ def update(
+ cls, model: Model, properties: Dict[str, Any], commit: bool = True
+ ) -> Model:
"""
Generic update a model
:raises: DAOCreateFailedError
diff --git a/superset/dashboards/commands/create.py b/superset/dashboards/commands/create.py
index 0aa1241..8376369 100644
--- a/superset/dashboards/commands/create.py
+++ b/superset/dashboards/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
class CreateDashboardCommand(BaseCommand):
- def __init__(self, user: User, data: Dict):
+ def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()
diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py
index 7746b7e..54c5de1 100644
--- a/superset/dashboards/commands/update.py
+++ b/superset/dashboards/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class UpdateDashboardCommand(BaseCommand):
- def __init__(self, user: User, model_id: int, data: Dict):
+ def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id
self._properties = data.copy()
diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py
index 3114a4f..436fdd2 100644
--- a/superset/datasets/commands/create.py
+++ b/superset/datasets/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
class CreateDatasetCommand(BaseCommand):
- def __init__(self, user: User, data: Dict):
+ def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()
diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py
index c7f70dd..14cc087 100644
--- a/superset/datasets/commands/update.py
+++ b/superset/datasets/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from collections import Counter
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@@ -48,7 +48,7 @@ logger = logging.getLogger(__name__)
class UpdateDatasetCommand(BaseCommand):
- def __init__(self, user: User, model_id: int, data: Dict):
+ def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id
self._properties = data.copy()
@@ -111,7 +111,7 @@ class UpdateDatasetCommand(BaseCommand):
raise exception
def _validate_columns(
- self, columns: List[Dict], exceptions: List[ValidationError]
+ self, columns: List[Dict[str, Any]], exceptions: List[ValidationError]
) -> None:
# Validate duplicates on data
if self._get_duplicates(columns, "column_name"):
@@ -133,7 +133,7 @@ class UpdateDatasetCommand(BaseCommand):
exceptions.append(DatasetColumnsExistsValidationError())
def _validate_metrics(
- self, metrics: List[Dict], exceptions: List[ValidationError]
+ self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError]
) -> None:
if self._get_duplicates(metrics, "metric_name"):
exceptions.append(DatasetMetricsDuplicateValidationError())
@@ -152,7 +152,7 @@ class UpdateDatasetCommand(BaseCommand):
exceptions.append(DatasetMetricsExistsValidationError())
@staticmethod
- def _get_duplicates(data: List[Dict], key: str) -> List[str]:
+ def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]:
duplicates = [
name
for name, count in Counter([item[key] for item in data]).items()
diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py
index 5dfe4ef..ef20a69 100644
--- a/superset/datasets/dao.py
+++ b/superset/datasets/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
@@ -116,7 +116,7 @@ class DatasetDAO(BaseDAO):
@classmethod
def update(
- cls, model: SqlaTable, properties: Dict, commit: bool = True
+ cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlaTable]:
"""
Updates a Dataset model on the metadata DB
@@ -151,13 +151,13 @@ class DatasetDAO(BaseDAO):
@classmethod
def update_column(
- cls, model: TableColumn, properties: Dict, commit: bool = True
+ cls, model: TableColumn, properties: Dict[str, Any], commit: bool = True
) -> Optional[TableColumn]:
return DatasetColumnDAO.update(model, properties, commit=commit)
@classmethod
def create_column(
- cls, properties: Dict, commit: bool = True
+ cls, properties: Dict[str, Any], commit: bool = True
) -> Optional[TableColumn]:
"""
Creates a Dataset model on the metadata DB
@@ -166,13 +166,13 @@ class DatasetDAO(BaseDAO):
@classmethod
def update_metric(
- cls, model: SqlMetric, properties: Dict, commit: bool = True
+ cls, model: SqlMetric, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlMetric]:
return DatasetMetricDAO.update(model, properties, commit=commit)
@classmethod
def create_metric(
- cls, properties: Dict, commit: bool = True
+ cls, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlMetric]:
"""
Creates a Dataset model on the metadata DB
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index a593f59..7b0d537 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -151,7 +151,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
# default matching patterns for identifying column types
- db_column_types: Dict[utils.DbColumnType, Tuple[Pattern, ...]] = {
+ db_column_types: Dict[utils.DbColumnType, Tuple[Pattern[Any], ...]] = {
utils.DbColumnType.NUMERIC: (
re.compile(r".*DOUBLE.*", re.IGNORECASE),
re.compile(r".*FLOAT.*", re.IGNORECASE),
@@ -296,7 +296,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return select_exprs
@classmethod
- def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
+ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
"""
:param cursor: Cursor instance
@@ -311,8 +311,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def expand_data(
- cls, columns: List[dict], data: List[dict]
- ) -> Tuple[List[dict], List[dict], List[dict]]:
+ cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
+ ) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
"""
Some engines support expanding nested fields. See implementation in Presto
spec for details.
@@ -645,7 +645,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
schema: Optional[str],
database: "Database",
query: Select,
- columns: Optional[List] = None,
+ columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]:
"""
Add a where clause to a query to reference only the most recent partition
@@ -925,7 +925,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return []
@staticmethod
- def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple]:
+ def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]:
"""
Convert pyodbc.Row objects from `fetch_data` to tuples.
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 992b5fe..3091d65 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -83,7 +83,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
return None
@classmethod
- def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
+ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Support type BigQuery Row, introduced here PR #4071
# google.cloud.bigquery.table.Row
diff --git a/superset/db_engine_specs/exasol.py b/superset/db_engine_specs/exasol.py
index 480a8c2..23449f0 100644
--- a/superset/db_engine_specs/exasol.py
+++ b/superset/db_engine_specs/exasol.py
@@ -39,7 +39,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
}
@classmethod
- def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
+ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 3fb09ef..63c27bd 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -93,7 +93,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
@classmethod
- def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
+ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
import pyhive
from TCLIService import ttypes
@@ -304,7 +304,7 @@ class HiveEngineSpec(PrestoEngineSpec):
schema: Optional[str],
database: "Database",
query: Select,
- columns: Optional[List] = None,
+ columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]:
try:
col_names, values = cls.latest_partition(
@@ -323,7 +323,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return None
@classmethod
- def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
+ def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
@classmethod
diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index fde69b3..45c2f23 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -66,7 +66,7 @@ class MssqlEngineSpec(BaseEngineSpec):
return None
@classmethod
- def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
+ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index b5f1b2c..c5e4221 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -51,7 +51,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
}
@classmethod
- def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
+ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
cursor.tzinfo_factory = FixedOffsetTimezone
if not cursor.description:
return []
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 9bc9307..e8c9603 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -164,7 +164,7 @@ class PrestoEngineSpec(BaseEngineSpec):
return [row[0] for row in results]
@classmethod
- def _create_column_info(cls, name: str, data_type: str) -> dict:
+ def _create_column_info(cls, name: str, data_type: str) -> Dict[str, Any]:
"""
Create column info object
:param name: column name
@@ -213,7 +213,10 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branches
- cls, parent_column_name: str, parent_data_type: str, result: List[dict]
+ cls,
+ parent_column_name: str,
+ parent_data_type: str,
+ result: List[Dict[str, Any]],
) -> None:
"""
Parse a row or array column
@@ -322,7 +325,7 @@ class PrestoEngineSpec(BaseEngineSpec):
(i.e. column name and data type)
"""
columns = cls._show_columns(inspector, table_name, schema)
- result: List[dict] = []
+ result: List[Dict[str, Any]] = []
for column in columns:
try:
# parse column if it is a row or array
@@ -361,7 +364,7 @@ class PrestoEngineSpec(BaseEngineSpec):
return column_name.startswith('"') and column_name.endswith('"')
@classmethod
- def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
+ def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
"""
Format column clauses where names are in quotes and labels are specified
:param cols: columns
@@ -561,8 +564,8 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def expand_data( # pylint: disable=too-many-locals
- cls, columns: List[dict], data: List[dict]
- ) -> Tuple[List[dict], List[dict], List[dict]]:
+ cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
+ ) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
"""
We do not immediately display rows and arrays clearly in the data grid. This
method separates out nested fields and data values to help clearly display
@@ -590,7 +593,7 @@ class PrestoEngineSpec(BaseEngineSpec):
# process each column, unnesting ARRAY types and
# expanding ROW types into new columns
to_process = deque((column, 0) for column in columns)
- all_columns: List[dict] = []
+ all_columns: List[Dict[str, Any]] = []
expanded_columns = []
current_array_level = None
while to_process:
@@ -843,7 +846,7 @@ class PrestoEngineSpec(BaseEngineSpec):
schema: Optional[str],
database: "Database",
query: Select,
- columns: Optional[List] = None,
+ columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]:
try:
col_names, values = cls.latest_partition(
diff --git a/superset/extensions.py b/superset/extensions.py
index f321046..a0dad81 100644
--- a/superset/extensions.py
+++ b/superset/extensions.py
@@ -95,7 +95,9 @@ class UIManifestProcessor:
self.parse_manifest_json()
@app.context_processor
- def get_manifest() -> Dict[str, Callable]: # pylint: disable=unused-variable
+ def get_manifest() -> Dict[ # pylint: disable=unused-variable
+ str, Callable[[str], List[str]]
+ ]:
loaded_chunks = set()
def get_files(bundle: str, asset_type: str = "js") -> List[str]:
@@ -131,7 +133,7 @@ appbuilder = AppBuilder(update_perms=False)
cache_manager = CacheManager()
celery_app = celery.Celery()
db = SQLA()
-_event_logger: dict = {}
+_event_logger: Dict[str, Any] = {}
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager()
jinja_context_manager = JinjaContextManager()
diff --git a/superset/models/core.py b/superset/models/core.py
index 015abc2..ed90e4a 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -341,11 +341,14 @@ class Database(
def get_reserved_words(self) -> Set[str]:
return self.get_dialect().preparer.reserved_words
- def get_quoter(self) -> Callable:
+ def get_quoter(self) -> Callable[[str, Any], str]:
return self.get_dialect().identifier_preparer.quote
def get_df( # pylint: disable=too-many-locals
- self, sql: str, schema: Optional[str] = None, mutator: Optional[Callable] = None
+ self,
+ sql: str,
+ schema: Optional[str] = None,
+ mutator: Optional[Callable[[pd.DataFrame], None]] = None,
) -> pd.DataFrame:
sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)]
@@ -450,7 +453,7 @@ class Database(
@cache_util.memoized_func(
key=lambda *args, **kwargs: "db:{}:schema:None:view_list",
- attribute_in_key="id", # type: ignore
+ attribute_in_key="id",
)
def get_all_view_names_in_database(
self,
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index de42285..d10809e 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -240,7 +240,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
self.json_metadata = value
@property
- def position(self) -> Dict:
+ def position(self) -> Dict[str, Any]:
if self.position_json:
return json.loads(self.position_json)
return {}
@@ -315,7 +315,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
old_to_new_slc_id_dict: Dict[int, int] = {}
new_timed_refresh_immune_slices = []
new_expanded_slices = {}
- new_filter_scopes: Dict[str, Dict] = {}
+ new_filter_scopes = {}
i_params_dict = dashboard_to_import.params_dict
remote_id_slice_map = {
slc.params_dict["remote_id"]: slc
@@ -351,7 +351,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
# are converted to filter_scopes
# but dashboard create from import may still have old dashboard filter metadata
# here we convert them to new filter_scopes metadata first
- filter_scopes: Dict = {}
+ filter_scopes = {}
if (
"filter_immune_slices" in i_params_dict
or "filter_immune_slice_fields" in i_params_dict
@@ -415,7 +415,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
@classmethod
def export_dashboards( # pylint: disable=too-many-locals
- cls, dashboard_ids: List
+ cls, dashboard_ids: List[int]
) -> str:
copied_dashboards = []
datasource_ids = set()
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 42169e6..4ffe5e9 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -81,7 +81,7 @@ class ImportMixin:
for u in cls.__table_args__ # type: ignore
if isinstance(u, UniqueConstraint)
]
- unique.extend( # type: ignore
+ unique.extend(
{c.name} for c in cls.__table__.columns if c.unique # type: ignore
)
return unique
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 76eb457..4f73e43 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -36,7 +36,7 @@ from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils import core as utils
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
- from superset.viz_sip38 import BaseViz, viz_types # type: ignore
+ from superset.viz_sip38 import BaseViz, viz_types
else:
from superset.viz import BaseViz, viz_types # type: ignore
diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py
index a50b4c2..47486cf 100644
--- a/superset/models/sql_types/presto_sql_types.py
+++ b/superset/models/sql_types/presto_sql_types.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Optional, Type
+from typing import Any, Dict, List, Optional, Type
from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer
@@ -29,7 +29,7 @@ class TinyInteger(Integer):
A type for tiny ``int`` integers.
"""
- def python_type(self) -> Type:
+ def python_type(self) -> Type[int]:
return int
@classmethod
@@ -42,7 +42,7 @@ class Interval(TypeEngine):
A type for intervals.
"""
- def python_type(self) -> Optional[Type]:
+ def python_type(self) -> Optional[Type[Any]]:
return None
@classmethod
@@ -55,7 +55,7 @@ class Array(TypeEngine):
A type for arrays.
"""
- def python_type(self) -> Optional[Type]:
+ def python_type(self) -> Optional[Type[List[Any]]]:
return list
@classmethod
@@ -68,7 +68,7 @@ class Map(TypeEngine):
A type for maps.
"""
- def python_type(self) -> Optional[Type]:
+ def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
return dict
@classmethod
@@ -81,7 +81,7 @@ class Row(TypeEngine):
A type for rows.
"""
- def python_type(self) -> Optional[Type]:
+ def python_type(self) -> Optional[Type[Any]]:
return None
@classmethod
diff --git a/superset/queries/filters.py b/superset/queries/filters.py
index 323c3c6..22cf45f 100644
--- a/superset/queries/filters.py
+++ b/superset/queries/filters.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Callable
+from typing import Any
from flask import g
from flask_sqlalchemy import BaseQuery
@@ -25,7 +25,7 @@ from superset.views.base import BaseFilter
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
- def apply(self, query: BaseQuery, value: Callable) -> BaseQuery:
+ def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
"""
Filter queries to only those owned by current user. If
can_access_all_queries permission is set a user can list all queries
diff --git a/superset/result_set.py b/superset/result_set.py
index 4880511..dd6f0ff 100644
--- a/superset/result_set.py
+++ b/superset/result_set.py
@@ -20,7 +20,7 @@
import datetime
import json
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
import numpy as np
import pandas as pd
@@ -64,7 +64,7 @@ def stringify(obj: Any) -> str:
def stringify_values(array: np.ndarray) -> np.ndarray:
- vstringify: Callable = np.vectorize(stringify)
+ vstringify = np.vectorize(stringify)
return vstringify(array)
@@ -172,7 +172,7 @@ class SupersetResultSet:
return table.to_pandas(integer_object_nulls=True)
@staticmethod
- def first_nonempty(items: List) -> Any:
+ def first_nonempty(items: List[Any]) -> Any:
return next((i for i in items if i), None)
def is_temporal(self, db_type_str: Optional[str]) -> bool:
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 51cb3e0..772edbb 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -21,11 +21,11 @@ from typing import Any, Callable, List, Optional, Set, Tuple, TYPE_CHECKING, Uni
from flask import current_app, g
from flask_appbuilder import Model
-from flask_appbuilder.security.sqla import models as ab_models
from flask_appbuilder.security.sqla.manager import SecurityManager
from flask_appbuilder.security.sqla.models import (
assoc_permissionview_role,
assoc_user_role,
+ PermissionView,
)
from flask_appbuilder.security.views import (
PermissionModelView,
@@ -602,11 +602,8 @@ class SupersetSecurityManager(SecurityManager):
logger.info("Cleaning faulty perms")
sesh = self.get_session
- pvms = sesh.query(ab_models.PermissionView).filter(
- or_(
- ab_models.PermissionView.permission == None,
- ab_models.PermissionView.view_menu == None,
- )
+ pvms = sesh.query(PermissionView).filter(
+ or_(PermissionView.permission == None, PermissionView.view_menu == None,)
)
deleted_count = pvms.delete()
sesh.commit()
@@ -640,7 +637,9 @@ class SupersetSecurityManager(SecurityManager):
self.get_session.commit()
self.clean_perms()
- def set_role(self, role_name: str, pvm_check: Callable) -> None:
+ def set_role(
+ self, role_name: str, pvm_check: Callable[[PermissionView], bool]
+ ) -> None:
"""
Set the FAB permission/views for the role.
@@ -650,7 +649,7 @@ class SupersetSecurityManager(SecurityManager):
logger.info("Syncing {} perms".format(role_name))
sesh = self.get_session
- pvms = sesh.query(ab_models.PermissionView).all()
+ pvms = sesh.query(PermissionView).all()
pvms = [p for p in pvms if p.permission and p.view_menu]
role = self.add_role(role_name)
role_pvms = [p for p in pvms if pvm_check(p)]
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 3ba1e3a..d9f5b38 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -299,9 +299,10 @@ def _serialize_and_expand_data(
db_engine_spec: BaseEngineSpec,
use_msgpack: Optional[bool] = False,
expand_data: bool = False,
-) -> Tuple[Union[bytes, str], list, list, list]:
- selected_columns: List[Dict] = result_set.columns
- expanded_columns: List[Dict]
+) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]:
+ selected_columns = result_set.columns
+ all_columns: List[Any]
+ expanded_columns: List[Any]
if use_msgpack:
with stats_timing(
diff --git a/superset/tasks/celery_app.py b/superset/tasks/celery_app.py
index 0344b59..0f3cd0e 100644
--- a/superset/tasks/celery_app.py
+++ b/superset/tasks/celery_app.py
@@ -25,7 +25,7 @@ from superset import create_app
from superset.extensions import celery_app
# Init the Flask app / configure everything
-create_app() # type: ignore
+create_app()
# Need to import late, as the celery_app will have been setup by "create_app()"
# pylint: disable=wrong-import-position, unused-import
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index 2a5733e..3e6c1dd 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -23,7 +23,7 @@ import urllib.request
from collections import namedtuple
from datetime import datetime, timedelta
from email.utils import make_msgid, parseaddr
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
from urllib.error import URLError # pylint: disable=ungrouped-imports
import croniter
@@ -36,7 +36,6 @@ from flask_login import login_user
from retry.api import retry_call
from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
-from werkzeug.datastructures import TypeConversionDict
from werkzeug.http import parse_cookie
# Superset framework imports
@@ -53,6 +52,11 @@ from superset.models.schedules import (
)
from superset.utils.core import get_email_address_list, send_email_smtp
+if TYPE_CHECKING:
+ # pylint: disable=unused-import
+ from werkzeug.datastructures import TypeConversionDict
+
+
# Globals
config = app.config
logger = logging.getLogger("tasks.email_reports")
@@ -131,7 +135,7 @@ def _generate_mail_content(
return EmailContent(body, data, images)
-def _get_auth_cookies() -> List[TypeConversionDict]:
+def _get_auth_cookies() -> List["TypeConversionDict[Any, Any]"]:
# Login with the user specified to get the reports
with app.test_request_context():
user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index bd39f87..586cb2b 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -27,8 +27,9 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-
def memoized_func(
- key: Callable = view_cache_key, attribute_in_key: Optional[str] = None
-) -> Callable:
+ key: Callable[..., str] = view_cache_key, # pylint: disable=bad-whitespace
+ attribute_in_key: Optional[str] = None,
+) -> Callable[..., Any]:
"""Use this decorator to cache functions that have predefined first arg.
enable_cache is treated as True by default,
@@ -45,7 +46,7 @@ def memoized_func(
returns the caching key.
"""
- def wrap(f: Callable) -> Callable:
+ def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
if cache_manager.tables_cache:
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 00e3484..1620542 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -85,7 +85,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
-from superset.typing import FormData, Metric
+from superset.typing import FlaskResponse, FormData, Metric
from superset.utils.dates import datetime_to_epoch, EPOCH
try:
@@ -147,7 +147,9 @@ class _memoized:
should account for instance variable changes.
"""
- def __init__(self, func: Callable, watch: Optional[Tuple[str, ...]] = None) -> None:
+ def __init__(
+ self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None
+ ) -> None:
self.func = func
self.cache: Dict[Any, Any] = {}
self.is_method = False
@@ -173,7 +175,7 @@ class _memoized:
"""Return the function's docstring."""
return self.func.__doc__ or ""
- def __get__(self, obj: Any, objtype: Type) -> functools.partial:
+ def __get__(self, obj: Any, objtype: Type[Any]) -> functools.partial: # type: ignore
if not self.is_method:
self.is_method = True
"""Support instance methods."""
@@ -181,13 +183,13 @@ class _memoized:
def memoized(
- func: Optional[Callable] = None, watch: Optional[Tuple[str, ...]] = None
-) -> Callable:
+ func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None
+) -> Callable[..., Any]:
if func:
return _memoized(func)
else:
- def wrapper(f: Callable) -> Callable:
+ def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
return _memoized(f, watch)
return wrapper
@@ -1241,7 +1243,9 @@ def create_ssl_cert_file(certificate: str) -> str:
return path
-def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]:
+def time_function(
+ func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any
+) -> Tuple[float, Any]:
"""
Measures the amount of time a function takes to execute in ms
diff --git a/superset/utils/dashboard_filter_scopes_converter.py b/superset/utils/dashboard_filter_scopes_converter.py
index f77e0e0..d95582d 100644
--- a/superset/utils/dashboard_filter_scopes_converter.py
+++ b/superset/utils/dashboard_filter_scopes_converter.py
@@ -29,7 +29,7 @@ def convert_filter_scopes(
) -> Dict[int, Dict[str, Dict[str, Any]]]:
filter_scopes = {}
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
- immuned_by_column: Dict = defaultdict(list)
+ immuned_by_column: Dict[str, List[int]] = defaultdict(list)
for slice_id, columns in json_metadata.get(
"filter_immune_slice_fields", {}
).items():
@@ -52,7 +52,7 @@ def convert_filter_scopes(
logging.info(f"slice [{filter_id}] has invalid field: {filter_field}")
for filter_slice in filters:
- filter_fields: Dict = {}
+ filter_fields: Dict[str, Dict[str, Any]] = {}
filter_id = filter_slice.id
slice_params = json.loads(filter_slice.params or "{}")
configs = slice_params.get("filter_configs") or []
@@ -77,9 +77,10 @@ def convert_filter_scopes(
def copy_filter_scopes(
- old_to_new_slc_id_dict: Dict[int, int], old_filter_scopes: Dict[str, Dict]
-) -> Dict:
- new_filter_scopes: Dict[str, Dict] = {}
+ old_to_new_slc_id_dict: Dict[int, int],
+ old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]],
+) -> Dict[str, Dict[Any, Any]]:
+ new_filter_scopes: Dict[str, Dict[Any, Any]] = {}
for (filter_id, scopes) in old_filter_scopes.items():
new_filter_key = old_to_new_slc_id_dict.get(int(filter_id))
if new_filter_key:
diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py
index a1165c5..bb0219c 100644
--- a/superset/utils/decorators.py
+++ b/superset/utils/decorators.py
@@ -46,7 +46,7 @@ def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[floa
stats_logger.timing(stats_key, now_as_float() - start_ts)
-def etag_cache(max_age: int, check_perms: Callable) -> Callable:
+def etag_cache(max_age: int, check_perms: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator for caching views and handling etag conditional requests.
@@ -60,7 +60,7 @@ def etag_cache(max_age: int, check_perms: Callable) -> Callable:
"""
- def decorator(f: Callable) -> Callable:
+ def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
# check if the user can access the resource
diff --git a/superset/utils/import_datasource.py b/superset/utils/import_datasource.py
index 19f6d59..50f375c 100644
--- a/superset/utils/import_datasource.py
+++ b/superset/utils/import_datasource.py
@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
def import_datasource(
session: Session,
i_datasource: Model,
- lookup_database: Callable,
- lookup_datasource: Callable,
+ lookup_database: Callable[[Model], Model],
+ lookup_datasource: Callable[[Model], Model],
import_time: Optional[int] = None,
) -> int:
"""Imports the datasource from the object to the database.
@@ -82,7 +82,9 @@ def import_datasource(
return datasource.id
-def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model:
+def import_simple_obj(
+ session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
+) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None
diff --git a/superset/utils/log.py b/superset/utils/log.py
index aafe3b8..b31abce 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -35,7 +35,7 @@ class AbstractEventLogger(ABC):
) -> None:
pass
- def log_this(self, f: Callable) -> Callable:
+ def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
user_id = None
@@ -124,7 +124,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
)
)
- event_logger_type = cast(Type, cfg_value)
+ event_logger_type = cast(Type[Any], cfg_value)
result = event_logger_type()
# Verify that we have a valid logger impl
diff --git a/superset/utils/logging_configurator.py b/superset/utils/logging_configurator.py
index 396d35e..09f1e58 100644
--- a/superset/utils/logging_configurator.py
+++ b/superset/utils/logging_configurator.py
@@ -58,7 +58,7 @@ class DefaultLoggingConfigurator(LoggingConfigurator):
if app_config["ENABLE_TIME_ROTATE"]:
logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"])
- handler = TimedRotatingFileHandler( # type: ignore
+ handler = TimedRotatingFileHandler(
app_config["FILENAME"],
when=app_config["ROLLOVER"],
interval=app_config["INTERVAL"],
diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index 39a4278..e62b393 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -72,8 +72,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
)
-def validate_column_args(*argnames: str) -> Callable:
- def wrapper(func: Callable) -> Callable:
+def validate_column_args(*argnames: str) -> Callable[..., Any]:
+ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
columns = df.columns.tolist()
for name in argnames:
@@ -471,7 +471,7 @@ def geodetic_parse(
Parse a string containing a geodetic point and return latitude, longitude
and altitude
"""
- point = Point(location) # type: ignore
+ point = Point(location)
return point[0], point[1], point[2]
try:
diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py
index e07d2a2..b2f222f 100644
--- a/superset/utils/screenshots.py
+++ b/superset/utils/screenshots.py
@@ -51,7 +51,7 @@ SELENIUM_HEADSTART = 3
WindowSize = Tuple[int, int]
-def get_auth_cookies(user: "User") -> List[Dict]:
+def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]:
# Login with the user specified to get the reports
with current_app.test_request_context("/login"):
login_user(user)
@@ -101,14 +101,14 @@ class AuthWebDriverProxy:
self,
driver_type: str,
window: Optional[WindowSize] = None,
- auth_func: Optional[Callable] = None,
+ auth_func: Optional[
+ Callable[..., Any]
+ ] = None, # pylint: disable=bad-whitespace
):
self._driver_type = driver_type
self._window: WindowSize = window or (800, 600)
- config_auth_func: Callable = current_app.config.get(
- "WEBDRIVER_AUTH_FUNC", auth_driver
- )
- self._auth_func: Callable = auth_func or config_auth_func
+ config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", auth_driver)
+ self._auth_func = auth_func or config_auth_func
def create(self) -> WebDriver:
if self._driver_type == "firefox":
@@ -123,7 +123,7 @@ class AuthWebDriverProxy:
raise Exception(f"Webdriver name ({self._driver_type}) not supported")
# Prepare args for the webdriver init
options.add_argument("--headless")
- kwargs: Dict = dict(options=options)
+ kwargs: Dict[Any, Any] = dict(options=options)
kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
logger.info("Init selenium driver")
return driver_class(**kwargs)
diff --git a/superset/views/base.py b/superset/views/base.py
index 1238218..bbf18c1 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -143,7 +143,7 @@ def generate_download_headers(
return headers
-def api(f: Callable) -> Callable:
+def api(f: Callable[..., FlaskResponse]) -> Callable[..., FlaskResponse]:
"""
A decorator to label an endpoint as an API. Catches uncaught exceptions and
return the response in the JSON format
@@ -383,11 +383,11 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
:param primary_key:
record primary key to delete
"""
- item = self.datamodel.get(primary_key, self._base_filters) # type: ignore
+ item = self.datamodel.get(primary_key, self._base_filters)
if not item:
abort(404)
try:
- self.pre_delete(item) # type: ignore
+ self.pre_delete(item)
except Exception as ex: # pylint: disable=broad-except
flash(str(ex), "danger")
else:
@@ -400,8 +400,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
.all()
)
- if self.datamodel.delete(item): # type: ignore
- self.post_delete(item) # type: ignore
+ if self.datamodel.delete(item):
+ self.post_delete(item)
for pv in pvs:
security_manager.get_session.delete(pv)
@@ -411,8 +411,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
security_manager.get_session.commit()
- flash(*self.datamodel.message) # type: ignore
- self.update_redirect() # type: ignore
+ flash(*self.datamodel.message)
+ self.update_redirect()
@action(
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False
diff --git a/superset/views/base_api.py b/superset/views/base_api.py
index 3d40c33..a72a1c5 100644
--- a/superset/views/base_api.py
+++ b/superset/views/base_api.py
@@ -41,7 +41,7 @@ get_related_schema = {
}
-def statsd_metrics(f: Callable) -> Callable:
+def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]:
"""
Handle sending all statsd metrics from the REST API
"""
diff --git a/superset/views/base_schemas.py b/superset/views/base_schemas.py
index a4436dd..87a9190 100644
--- a/superset/views/base_schemas.py
+++ b/superset/views/base_schemas.py
@@ -88,7 +88,9 @@ class BaseOwnedSchema(BaseSupersetSchema):
owners_field_name = "owners"
@post_load
- def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
+ def make_object(
+ self, data: Dict[str, Any], discard: Optional[List[str]] = None
+ ) -> Model:
discard = discard or []
discard.append(self.owners_field_name)
instance = super().make_object(data, discard)
diff --git a/superset/views/core.py b/superset/views/core.py
index c561222..44cade9 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -251,7 +251,7 @@ def check_slice_perms(self: "Superset", slice_id: int) -> None:
def _deserialize_results_payload(
payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False
-) -> Dict[Any, Any]:
+) -> Dict[str, Any]:
logger.debug(f"Deserializing from msgpack: {use_msgpack}")
if use_msgpack:
with stats_timing(
@@ -278,7 +278,7 @@ def _deserialize_results_payload(
with stats_timing(
"sqllab.query.results_backend_json_deserialize", stats_logger
):
- return json.loads(payload) # type: ignore
+ return json.loads(payload)
def get_cta_schema_name(
@@ -1343,7 +1343,7 @@ class Superset(BaseSupersetView):
if "timed_refresh_immune_slices" not in md:
md["timed_refresh_immune_slices"] = []
- new_filter_scopes: Dict[str, Dict] = {}
+ new_filter_scopes = {}
if "filter_scopes" in data:
# replace filter_id and immune ids from old slice id to new slice id:
# and remove slice ids that are not in dash anymore
@@ -2137,7 +2137,7 @@ class Superset(BaseSupersetView):
f"deprecated.{self.__class__.__name__}.select_star.database_not_found"
)
return json_error_response("Not found", 404)
- schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore
+ schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
table_name = utils.parse_js_uri_path_item(table_name) # type: ignore
# Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource(
@@ -2245,7 +2245,7 @@ class Superset(BaseSupersetView):
)
payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack)
- obj: dict = _deserialize_results_payload(
+ obj = _deserialize_results_payload(
payload, query, cast(bool, results_backend_use_msgpack)
)
@@ -2474,9 +2474,7 @@ class Superset(BaseSupersetView):
schema: str = cast(str, query_params.get("schema"))
sql: str = cast(str, query_params.get("sql"))
try:
- template_params: dict = json.loads(
- query_params.get("templateParams") or "{}"
- )
+ template_params = json.loads(query_params.get("templateParams") or "{}")
except json.JSONDecodeError:
logger.warning(
f"Invalid template parameter {query_params.get('templateParams')}"
diff --git a/superset/views/database/api.py b/superset/views/database/api.py
index 0050326..fe328d6 100644
--- a/superset/views/database/api.py
+++ b/superset/views/database/api.py
@@ -61,7 +61,7 @@ def get_col_type(col: Dict[Any, Any]) -> str:
def get_table_metadata(
database: Database, table_name: str, schema_name: Optional[str]
-) -> Dict:
+) -> Dict[str, Any]:
"""
Get table metadata information, including type, pk, fks.
This function raises SQLAlchemyError when a schema is not found.
@@ -72,7 +72,7 @@ def get_table_metadata(
:param schema_name: schema name
:return: Dict table metadata ready for API response
"""
- keys: List = []
+ keys = []
columns = database.get_columns(table_name, schema_name)
primary_key = database.get_pk_constraint(table_name, schema_name)
if primary_key and primary_key.get("constrained_columns"):
@@ -82,7 +82,7 @@ def get_table_metadata(
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
indexes = get_indexes_metadata(database, table_name, schema_name)
keys += foreign_keys + indexes
- payload_columns: List[Dict] = []
+ payload_columns: List[Dict[str, Any]] = []
for col in columns:
dtype = get_col_type(col)
payload_columns.append(
@@ -90,7 +90,7 @@ def get_table_metadata(
"name": col["name"],
"type": dtype.split("(")[0] if "(" in dtype else dtype,
"longType": dtype,
- "keys": [k for k in keys if col["name"] in k.get("column_names")],
+ "keys": [k for k in keys if col["name"] in k["column_names"]],
}
)
return {
@@ -270,7 +270,7 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
"""
self.incr_stats("init", self.table_metadata.__name__)
try:
- table_info: Dict = get_table_metadata(database, table_name, schema_name)
+ table_info = get_table_metadata(database, table_name, schema_name)
except SQLAlchemyError as ex:
self.incr_stats("error", self.table_metadata.__name__)
return self.response_422(error_msg_from_exception(ex))
diff --git a/superset/views/database/decorators.py b/superset/views/database/decorators.py
index 0d2e83b..291a1af 100644
--- a/superset/views/database/decorators.py
+++ b/superset/views/database/decorators.py
@@ -29,7 +29,7 @@ from superset.views.base_api import BaseSupersetModelRestApi
logger = logging.getLogger(__name__)
-def check_datasource_access(f: Callable) -> Callable:
+def check_datasource_access(f: Callable[..., Any]) -> Callable[..., Any]:
"""
A Decorator that checks if a user has datasource access
"""
diff --git a/superset/views/schedules.py b/superset/views/schedules.py
index 68ae6ff..de09c31 100644
--- a/superset/views/schedules.py
+++ b/superset/views/schedules.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import enum
-from typing import Type
+from typing import Type, Union
import simplejson as json
from croniter import croniter
@@ -55,7 +55,7 @@ class EmailScheduleView(
raise NotImplementedError()
@property
- def schedule_type_model(self) -> Type:
+ def schedule_type_model(self) -> Type[Union[Dashboard, Slice]]:
raise NotImplementedError()
page_size = 20
@@ -154,9 +154,7 @@ class EmailScheduleView(
info[col] = info[col].username
info["user"] = schedule.user.username
- info[self.schedule_type] = getattr( # type: ignore
- schedule, self.schedule_type
- ).id
+ info[self.schedule_type] = getattr(schedule, self.schedule_type).id
schedules.append(info)
return json_success(json.dumps(schedules, default=json_iso_dttm_ser))
diff --git a/superset/views/sql_lab.py b/superset/views/sql_lab.py
index 3476bb3..534c6fb 100644
--- a/superset/views/sql_lab.py
+++ b/superset/views/sql_lab.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Callable
+from typing import Any
import simplejson as json
from flask import g, redirect, request, Response
@@ -40,7 +40,7 @@ from .base import (
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
- def apply(self, query: BaseQuery, value: Callable) -> BaseQuery:
+ def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
"""
Filter queries to only those owned by current user. If
can_access_all_queries permission is set a user can list all queries
diff --git a/superset/views/utils.py b/superset/views/utils.py
index 4edd2e7..2a8b2bf 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -35,7 +35,7 @@ from superset.utils.core import QueryStatus, TimeRangeEndpoint
from superset.viz import BaseViz
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
- from superset import viz_sip38 as viz # type: ignore
+ from superset import viz_sip38 as viz
else:
from superset import viz # type: ignore
@@ -318,9 +318,9 @@ def get_dashboard_extra_filters(
def build_extra_filters(
- layout: Dict,
- filter_scopes: Dict,
- default_filters: Dict[str, Dict[str, List]],
+ layout: Dict[str, Dict[str, Any]],
+ filter_scopes: Dict[str, Dict[str, Any]],
+ default_filters: Dict[str, Dict[str, List[Any]]],
slice_id: int,
) -> List[Dict[str, Any]]:
extra_filters = []
@@ -343,7 +343,9 @@ def build_extra_filters(
return extra_filters
-def is_slice_in_container(layout: Dict, container_id: str, slice_id: int) -> bool:
+def is_slice_in_container(
+ layout: Dict[str, Dict[str, Any]], container_id: str, slice_id: int
+) -> bool:
if container_id == "ROOT_ID":
return True
diff --git a/superset/viz.py b/superset/viz.py
index d53dcf2..d34405c 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -2720,7 +2720,7 @@ class PairedTTestViz(BaseViz):
else:
cols.append(col)
df.columns = cols
- data: Dict = {}
+ data: Dict[str, List[Dict[str, Any]]] = {}
series = df.to_dict("series")
for nameSet in df.columns:
# If no groups are defined, nameSet will be the metric name
@@ -2750,7 +2750,7 @@ class RoseViz(NVD3TimeSeriesViz):
return None
data = super().get_data(df)
- result: Dict = {}
+ result: Dict[str, List[Dict[str, str]]] = {}
for datum in data: # type: ignore
key = datum["key"]
for val in datum["values"]:
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 88c5f7b..d6d6516 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -18,7 +18,7 @@
"""Unit tests for Superset"""
import imp
import json
-from typing import Dict, Union, List
+from typing import Any, Dict, Union, List
from unittest.mock import Mock, patch
import pandas as pd
@@ -397,7 +397,9 @@ class SupersetTestCase(TestCase):
mock_method.assert_called_once_with("error", func_name)
return rv
- def post_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response:
+ def post_assert_metric(
+ self, uri: str, data: Dict[str, Any], func_name: str
+ ) -> Response:
"""
Simple client post with an extra assertion for statsd metrics
@@ -417,7 +419,9 @@ class SupersetTestCase(TestCase):
mock_method.assert_called_once_with("error", func_name)
return rv
- def put_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response:
+ def put_assert_metric(
+ self, uri: str, data: Dict[str, Any], func_name: str
+ ) -> Response:
"""
Simple client put with an extra assertion for statsd metrics
diff --git a/tests/superset_test_config_thumbnails.py b/tests/superset_test_config_thumbnails.py
index bfcb3a3..3b97604 100644
--- a/tests/superset_test_config_thumbnails.py
+++ b/tests/superset_test_config_thumbnails.py
@@ -20,7 +20,7 @@ from copy import copy
from cachelib.redis import RedisCache
from flask import Flask
-from superset.config import * # type: ignore
+from superset.config import *
AUTH_USER_REGISTRATION_ROLE = "alpha"
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")