You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by jo...@apache.org on 2020/05/28 05:57:51 UTC
[incubator-superset] branch master updated: [mypy] Enforcing typing for superset.utils (#9905)
This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new b296a0f [mypy] Enforcing typing for superset.utils (#9905)
b296a0f is described below
commit b296a0f250979bf70e9cb2a2a2b48fd10038a363
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Wed May 27 22:57:30 2020 -0700
[mypy] Enforcing typing for superset.utils (#9905)
Co-authored-by: John Bodley <jo...@airbnb.com>
---
setup.cfg | 2 +-
superset/config.py | 2 +-
superset/typing.py | 1 +
superset/utils/cache.py | 10 +-
superset/utils/core.py | 161 +++++++++++----------
.../utils/dashboard_filter_scopes_converter.py | 22 +--
superset/utils/dashboard_import_export.py | 12 +-
superset/utils/dates.py | 4 +-
superset/utils/decorators.py | 21 ++-
superset/utils/dict_import_export.py | 15 +-
superset/utils/feature_flag_manager.py | 24 ++-
superset/utils/log.py | 30 ++--
superset/utils/pandas_postprocessing.py | 6 +-
superset/utils/screenshots.py | 6 +-
superset/utils/url_map_converters.py | 12 +-
superset/views/core.py | 2 +-
tests/utils_tests.py | 12 --
17 files changed, 194 insertions(+), 148 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index 99c09ce..1115de9 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
-[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
+[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
diff --git a/superset/config.py b/superset/config.py
index e3d3ffb..e0f22f7 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -279,7 +279,7 @@ LANGUAGES = {
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
-DEFAULT_FEATURE_FLAGS = {
+DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
# Experimental feature introducing a client (browser) cache
"CLIENT_CACHE": False,
"ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False,
diff --git a/superset/typing.py b/superset/typing.py
index f3db6ae..09a3393 100644
--- a/superset/typing.py
+++ b/superset/typing.py
@@ -28,6 +28,7 @@ DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, .
DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
FilterValue = Union[float, int, str]
FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
+FormData = Dict[str, Any]
Granularity = Union[str, Dict[str, Union[str, float]]]
Metric = Union[Dict[str, str], str]
QueryObjectDict = Dict[str, Any]
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index b555005..bd39f87 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -14,14 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
from flask import request
from superset.extensions import cache_manager
-def view_cache_key(*_, **__) -> str:
+def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-argument
args_hash = hash(frozenset(request.args.items()))
return "view/{}/{}".format(request.path, args_hash)
@@ -45,10 +45,10 @@ def memoized_func(
returns the caching key.
"""
- def wrap(f):
+ def wrap(f: Callable) -> Callable:
if cache_manager.tables_cache:
- def wrapped_f(self, *args, **kwargs):
+ def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
if not kwargs.get("cache", True):
return f(self, *args, **kwargs)
@@ -69,7 +69,7 @@ def memoized_func(
else:
# noop
- def wrapped_f(self, *args, **kwargs):
+ def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
return f(self, *args, **kwargs)
return wrapped_f
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 3618a28..b23136d 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -39,6 +39,7 @@ from email.utils import formatdate
from enum import Enum
from time import struct_time
from timeit import default_timer
+from types import TracebackType
from typing import (
Any,
Callable,
@@ -51,6 +52,7 @@ from typing import (
Sequence,
Set,
Tuple,
+ Type,
TYPE_CHECKING,
Union,
)
@@ -69,10 +71,12 @@ from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from flask import current_app, flash, g, Markup, render_template
from flask_appbuilder import SQLA
-from flask_appbuilder.security.sqla.models import User
+from flask_appbuilder.security.sqla.models import Role, User
from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import event, exc, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
+from sqlalchemy.engine import Connection, Engine
+from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator
@@ -81,7 +85,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
-from superset.typing import Metric
+from superset.typing import FormData, Metric
from superset.utils.dates import datetime_to_epoch, EPOCH
try:
@@ -90,6 +94,7 @@ except ImportError:
pass
if TYPE_CHECKING:
+ from superset.connectors.base.models import BaseDatasource
from superset.models.core import Database
@@ -121,7 +126,7 @@ except NameError:
pass
-def flasher(msg: str, severity: str) -> None:
+def flasher(msg: str, severity: str = "message") -> None:
"""Flask's flash if available, logging call if not"""
try:
flash(msg, severity)
@@ -142,17 +147,17 @@ class _memoized:
should account for instance variable changes.
"""
- def __init__(self, func, watch=()):
+ def __init__(self, func: Callable, watch: Optional[List[str]] = None) -> None:
self.func = func
- self.cache = {}
+ self.cache: Dict[Any, Any] = {}
self.is_method = False
self.watch = watch or []
- def __call__(self, *args, **kwargs):
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
key = [args, frozenset(kwargs.items())]
if self.is_method:
key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
- key = tuple(key)
+ key = tuple(key) # type: ignore
if key in self.cache:
return self.cache[key]
try:
@@ -164,23 +169,25 @@ class _memoized:
# Better to not cache than to blow up entirely.
return self.func(*args, **kwargs)
- def __repr__(self):
+ def __repr__(self) -> str:
"""Return the function's docstring."""
- return self.func.__doc__
+ return self.func.__doc__ or ""
- def __get__(self, obj, objtype):
+ def __get__(self, obj: Any, objtype: Type) -> functools.partial:
if not self.is_method:
self.is_method = True
"""Support instance methods."""
return functools.partial(self.__call__, obj)
-def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] = None):
+def memoized(
+ func: Optional[Callable] = None, watch: Optional[List[str]] = None
+) -> Callable:
if func:
return _memoized(func)
else:
- def wrapper(f):
+ def wrapper(f: Callable) -> Callable:
return _memoized(f, watch)
return wrapper
@@ -229,7 +236,7 @@ def cast_to_num(value: Union[float, int, str]) -> Optional[Union[float, int]]:
return None
-def list_minus(l: List, minus: List) -> List:
+def list_minus(l: List[Any], minus: List[Any]) -> List[Any]:
"""Returns l without what is in minus
>>> list_minus([1, 2, 3], [2])
@@ -284,19 +291,19 @@ def md5_hex(data: str) -> str:
class DashboardEncoder(json.JSONEncoder):
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.sort_keys = True
# pylint: disable=E0202
- def default(self, o):
+ def default(self, o: Any) -> Dict[Any, Any]:
try:
vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"}
return {"__{}__".format(o.__class__.__name__): vals}
except Exception:
if type(o) == datetime:
return {"__datetime__": o.replace(microsecond=0).isoformat()}
- return json.JSONEncoder(sort_keys=True).default(self, o)
+ return json.JSONEncoder(sort_keys=True).default(o)
def parse_human_timedelta(s: Optional[str]) -> timedelta:
@@ -332,28 +339,15 @@ class JSONEncodedDict(TypeDecorator):
impl = TEXT
- def process_bind_param(self, value, dialect):
- if value is not None:
- value = json.dumps(value)
+ def process_bind_param(
+ self, value: Optional[Dict[Any, Any]], dialect: str
+ ) -> Optional[str]:
+ return json.dumps(value) if value is not None else None
- return value
-
- def process_result_value(self, value, dialect):
- if value is not None:
- value = json.loads(value)
- return value
-
-
-def datetime_f(dttm):
- """Formats datetime to take less room when it is recent"""
- if dttm:
- dttm = dttm.isoformat()
- now_iso = datetime.now().isoformat()
- if now_iso[:10] == dttm[:10]:
- dttm = dttm[11:]
- elif now_iso[:4] == dttm[:4]:
- dttm = dttm[5:]
- return "<nobr>{}</nobr>".format(dttm)
+ def process_result_value(
+ self, value: Optional[str], dialect: str
+ ) -> Optional[Dict[Any, Any]]:
+ return json.loads(value) if value is not None else None
def format_timedelta(td: timedelta) -> str:
@@ -373,7 +367,7 @@ def format_timedelta(td: timedelta) -> str:
return str(td)
-def base_json_conv(obj):
+def base_json_conv(obj: Any) -> Any:
if isinstance(obj, memoryview):
obj = obj.tobytes()
if isinstance(obj, np.int64):
@@ -397,7 +391,7 @@ def base_json_conv(obj):
return "[bytes]"
-def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False):
+def json_iso_dttm_ser(obj: Any, pessimistic: bool = False) -> str:
"""
json serializer that deals with dates
@@ -420,14 +414,14 @@ def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False):
return obj
-def pessimistic_json_iso_dttm_ser(obj):
+def pessimistic_json_iso_dttm_ser(obj: Any) -> str:
"""Proxy to call json_iso_dttm_ser in a pessimistic way
If one of object is not serializable to json, it will still succeed"""
return json_iso_dttm_ser(obj, pessimistic=True)
-def json_int_dttm_ser(obj):
+def json_int_dttm_ser(obj: Any) -> float:
"""json serializer that deals with dates"""
val = base_json_conv(obj)
if val is not None:
@@ -441,7 +435,7 @@ def json_int_dttm_ser(obj):
return obj
-def json_dumps_w_dates(payload):
+def json_dumps_w_dates(payload: Dict[Any, Any]) -> str:
return json.dumps(payload, default=json_int_dttm_ser)
@@ -522,7 +516,7 @@ def readfile(file_path: str) -> Optional[str]:
def generic_find_constraint_name(
table: str, columns: Set[str], referenced: str, db: SQLA
-):
+) -> Optional[str]:
"""Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
@@ -530,10 +524,12 @@ def generic_find_constraint_name(
if fk.referred_table.name == referenced and set(fk.column_keys) == columns:
return fk.name
+ return None
+
def generic_find_fk_constraint_name(
- table: str, columns: Set[str], referenced: str, insp
-):
+ table: str, columns: Set[str], referenced: str, insp: Inspector
+) -> Optional[str]:
"""Utility to find a foreign-key constraint name in alembic migrations"""
for fk in insp.get_foreign_keys(table):
if (
@@ -542,8 +538,12 @@ def generic_find_fk_constraint_name(
):
return fk["name"]
+ return None
+
-def generic_find_fk_constraint_names(table, columns, referenced, insp):
+def generic_find_fk_constraint_names(
+ table: str, columns: Set[str], referenced: str, insp: Inspector
+) -> Set[str]:
"""Utility to find foreign-key constraint names in alembic migrations"""
names = set()
@@ -557,13 +557,17 @@ def generic_find_fk_constraint_names(table, columns, referenced, insp):
return names
-def generic_find_uq_constraint_name(table, columns, insp):
+def generic_find_uq_constraint_name(
+ table: str, columns: Set[str], insp: Inspector
+) -> Optional[str]:
"""Utility to find a unique constraint name in alembic migrations"""
for uq in insp.get_unique_constraints(table):
if columns == set(uq["column_names"]):
return uq["name"]
+ return None
+
def get_datasource_full_name(
database_name: str, datasource_name: str, schema: Optional[str] = None
@@ -582,30 +586,20 @@ def validate_json(obj: Union[bytes, bytearray, str]) -> None:
raise SupersetException("JSON is not valid")
-def table_has_constraint(table, name, db):
- """Utility to find a constraint name in alembic migrations"""
- t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
-
- for c in t.constraints:
- if c.name == name:
- return True
- return False
-
-
class timeout:
"""
To be used in a ``with`` block and timeout its content.
"""
- def __init__(self, seconds=1, error_message="Timeout"):
+ def __init__(self, seconds: int = 1, error_message: str = "Timeout") -> None:
self.seconds = seconds
self.error_message = error_message
- def handle_timeout(self, signum, frame):
+ def handle_timeout(self, signum: int, frame: Any) -> None:
logger.error("Process timed out")
raise SupersetTimeoutException(self.error_message)
- def __enter__(self):
+ def __enter__(self) -> None:
try:
signal.signal(signal.SIGALRM, self.handle_timeout)
signal.alarm(self.seconds)
@@ -613,7 +607,7 @@ class timeout:
logger.warning("timeout can't be used in the current context")
logger.exception(ex)
- def __exit__(self, type, value, traceback):
+ def __exit__(self, type: Any, value: Any, traceback: TracebackType) -> None:
try:
signal.alarm(0)
except ValueError as ex:
@@ -621,9 +615,9 @@ class timeout:
logger.exception(ex)
-def pessimistic_connection_handling(some_engine):
+def pessimistic_connection_handling(some_engine: Engine) -> None:
@event.listens_for(some_engine, "engine_connect")
- def ping_connection(connection, branch):
+ def ping_connection(connection: Connection, branch: bool) -> None:
if branch:
# 'branch' refers to a sub-connection of a connection,
# we don't want to bother pinging on these.
@@ -670,7 +664,14 @@ class QueryStatus:
TIMED_OUT: str = "timed_out"
-def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, config):
+def notify_user_about_perm_udate(
+ granter: User,
+ user: User,
+ role: Role,
+ datasource: "BaseDatasource",
+ tpl_name: str,
+ config: Dict[str, Any],
+) -> None:
msg = render_template(
tpl_name, granter=granter, user=user, role=role, datasource=datasource
)
@@ -762,7 +763,13 @@ def send_email_smtp(
send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun)
-def send_MIME_email(e_from, e_to, mime_msg, config, dryrun=False):
+def send_MIME_email(
+ e_from: str,
+ e_to: List[str],
+ mime_msg: MIMEMultipart,
+ config: Dict[str, Any],
+ dryrun: bool = False,
+) -> None:
SMTP_HOST = config["SMTP_HOST"]
SMTP_PORT = config["SMTP_PORT"]
SMTP_USER = config["SMTP_USER"]
@@ -800,7 +807,7 @@ def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]:
return [(v, v) for v in values]
-def zlib_compress(data):
+def zlib_compress(data: Union[bytes, str]) -> bytes:
"""
Compress things in a py2/3 safe fashion
>>> json_str = '{"test": 1}'
@@ -827,7 +834,9 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes,
return decompressed.decode("utf-8") if decode else decompressed
-def to_adhoc(filt, expressionType="SIMPLE", clause="where"):
+def to_adhoc(
+ filt: Dict[str, Any], expressionType: str = "SIMPLE", clause: str = "where"
+) -> Dict[str, Any]:
result = {
"clause": clause.upper(),
"expressionType": expressionType,
@@ -849,7 +858,7 @@ def to_adhoc(filt, expressionType="SIMPLE", clause="where"):
return result
-def merge_extra_filters(form_data: dict):
+def merge_extra_filters(form_data: Dict[str, Any]) -> None:
# extra_filters are temporary/contextual filters (using the legacy constructs)
# that are external to the slice definition. We use those for dynamic
# interactive filters like the ones emitted by the "Filter Box" visualization.
@@ -872,7 +881,7 @@ def merge_extra_filters(form_data: dict):
}
# Grab list of existing filters 'keyed' on the column and operator
- def get_filter_key(f):
+ def get_filter_key(f: Dict[str, Any]) -> str:
if "expressionType" in f:
return "{}__{}".format(f["subject"], f["operator"])
else:
@@ -945,7 +954,9 @@ def user_label(user: User) -> Optional[str]:
return None
-def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
+def get_or_create_db(
+ database_name: str, sqlalchemy_uri: str, *args: Any, **kwargs: Any
+) -> "Database":
from superset import db
from superset.models import core as models
@@ -996,7 +1007,7 @@ def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
return [get_metric_name(metric) for metric in metrics]
-def ensure_path_exists(path: str):
+def ensure_path_exists(path: str) -> None:
try:
os.makedirs(path)
except OSError as exc:
@@ -1119,7 +1130,7 @@ def add_ago_to_since(since: str) -> str:
return since
-def convert_legacy_filters_into_adhoc(fd):
+def convert_legacy_filters_into_adhoc(fd: FormData) -> None:
mapping = {"having": "having_filters", "where": "filters"}
if not fd.get("adhoc_filters"):
@@ -1138,7 +1149,7 @@ def convert_legacy_filters_into_adhoc(fd):
del fd[key]
-def split_adhoc_filters_into_base_filters(fd):
+def split_adhoc_filters_into_base_filters(fd: FormData) -> None:
"""
Mutates form data to restructure the adhoc filters in the form of the four base
filters, `where`, `having`, `filters`, and `having_filters` which represent
@@ -1230,7 +1241,7 @@ def create_ssl_cert_file(certificate: str) -> str:
return path
-def time_function(func: Callable, *args, **kwargs) -> Tuple[float, Any]:
+def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]:
"""
Measures the amount of time a function takes to execute in ms
@@ -1296,7 +1307,7 @@ def split(
yield s[i:]
-def get_iterable(x: Any) -> List:
+def get_iterable(x: Any) -> List[Any]:
"""
Get an iterable (list) representation of the object.
diff --git a/superset/utils/dashboard_filter_scopes_converter.py b/superset/utils/dashboard_filter_scopes_converter.py
index 6954990..f77e0e0 100644
--- a/superset/utils/dashboard_filter_scopes_converter.py
+++ b/superset/utils/dashboard_filter_scopes_converter.py
@@ -17,14 +17,16 @@
import json
import logging
from collections import defaultdict
-from typing import Dict, List
+from typing import Any, Dict, List
from superset.models.slice import Slice
logger = logging.getLogger(__name__)
-def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
+def convert_filter_scopes(
+ json_metadata: Dict[Any, Any], filters: List[Slice]
+) -> Dict[int, Dict[str, Dict[str, Any]]]:
filter_scopes = {}
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
immuned_by_column: Dict = defaultdict(list)
@@ -34,7 +36,9 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
for column in columns:
immuned_by_column[column].append(int(slice_id))
- def add_filter_scope(filter_field, filter_id):
+ def add_filter_scope(
+ filter_fields: Dict[str, Dict[str, Any]], filter_field: str, filter_id: int
+ ) -> None:
# in case filter field is invalid
if isinstance(filter_field, str):
current_filter_immune = list(
@@ -54,17 +58,17 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
configs = slice_params.get("filter_configs") or []
if slice_params.get("date_filter"):
- add_filter_scope("__time_range", filter_id)
+ add_filter_scope(filter_fields, "__time_range", filter_id)
if slice_params.get("show_sqla_time_column"):
- add_filter_scope("__time_col", filter_id)
+ add_filter_scope(filter_fields, "__time_col", filter_id)
if slice_params.get("show_sqla_time_granularity"):
- add_filter_scope("__time_grain", filter_id)
+ add_filter_scope(filter_fields, "__time_grain", filter_id)
if slice_params.get("show_druid_time_granularity"):
- add_filter_scope("__granularity", filter_id)
+ add_filter_scope(filter_fields, "__granularity", filter_id)
if slice_params.get("show_druid_time_origin"):
- add_filter_scope("druid_time_origin", filter_id)
+ add_filter_scope(filter_fields, "druid_time_origin", filter_id)
for config in configs:
- add_filter_scope(config.get("column"), filter_id)
+ add_filter_scope(filter_fields, config.get("column"), filter_id)
if filter_fields:
filter_scopes[filter_id] = filter_fields
diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py
index 53100f8..19dd2e2 100644
--- a/superset/utils/dashboard_import_export.py
+++ b/superset/utils/dashboard_import_export.py
@@ -19,6 +19,10 @@ import json
import logging
import time
from datetime import datetime
+from io import BytesIO
+from typing import Any, Dict, Optional
+
+from sqlalchemy.orm import Session
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.dashboard import Dashboard
@@ -27,7 +31,7 @@ from superset.models.slice import Slice
logger = logging.getLogger(__name__)
-def decode_dashboards(o):
+def decode_dashboards(o: Dict[str, Any]) -> Any:
"""
Function to be passed into json.loads obj_hook parameter
Recreates the dashboard object from a json representation.
@@ -50,7 +54,9 @@ def decode_dashboards(o):
return o
-def import_dashboards(session, data_stream, import_time=None):
+def import_dashboards(
+ session: Session, data_stream: BytesIO, import_time: Optional[int] = None
+) -> None:
"""Imports dashboards from a stream to databases"""
current_tt = int(time.time())
import_time = current_tt if import_time is None else import_time
@@ -64,7 +70,7 @@ def import_dashboards(session, data_stream, import_time=None):
session.commit()
-def export_dashboards(session):
+def export_dashboards(session: Session) -> str:
"""Returns all dashboards metadata as a json dump"""
logger.info("Starting export")
dashboards = session.query(Dashboard)
diff --git a/superset/utils/dates.py b/superset/utils/dates.py
index a1826e2..021ec7f 100644
--- a/superset/utils/dates.py
+++ b/superset/utils/dates.py
@@ -21,7 +21,7 @@ import pytz
EPOCH = datetime(1970, 1, 1)
-def datetime_to_epoch(dttm):
+def datetime_to_epoch(dttm: datetime) -> float:
if dttm.tzinfo:
dttm = dttm.replace(tzinfo=pytz.utc)
epoch_with_tz = pytz.utc.localize(EPOCH)
@@ -29,5 +29,5 @@ def datetime_to_epoch(dttm):
return (dttm - EPOCH).total_seconds() * 1000
-def now_as_float():
+def now_as_float() -> float:
return datetime_to_epoch(datetime.utcnow())
diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py
index 52ba61f..a1165c5 100644
--- a/superset/utils/decorators.py
+++ b/superset/utils/decorators.py
@@ -17,11 +17,14 @@
import logging
from datetime import datetime, timedelta
from functools import wraps
+from typing import Any, Callable, Iterator
from contextlib2 import contextmanager
from flask import request
+from werkzeug.wrappers.etag import ETagResponseMixin
from superset import app, cache
+from superset.stats_logger import BaseStatsLogger
from superset.utils.dates import now_as_float
# If a user sets `max_age` to 0, for long the browser should cache the
@@ -32,7 +35,7 @@ logger = logging.getLogger(__name__)
@contextmanager
-def stats_timing(stats_key, stats_logger):
+def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[float]:
"""Provide a transactional scope around a series of operations."""
start_ts = now_as_float()
try:
@@ -43,7 +46,7 @@ def stats_timing(stats_key, stats_logger):
stats_logger.timing(stats_key, now_as_float() - start_ts)
-def etag_cache(max_age, check_perms=bool):
+def etag_cache(max_age: int, check_perms: Callable) -> Callable:
"""
A decorator for caching views and handling etag conditional requests.
@@ -57,9 +60,9 @@ def etag_cache(max_age, check_perms=bool):
"""
- def decorator(f):
+ def decorator(f: Callable) -> Callable:
@wraps(f)
- def wrapper(*args, **kwargs):
+ def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
# check if the user can access the resource
check_perms(*args, **kwargs)
@@ -77,7 +80,9 @@ def etag_cache(max_age, check_perms=bool):
key_args = list(args)
key_kwargs = kwargs.copy()
key_kwargs.update(request.args)
- cache_key = wrapper.make_cache_key(f, *key_args, **key_kwargs)
+ cache_key = wrapper.make_cache_key( # type: ignore
+ f, *key_args, **key_kwargs
+ )
response = cache.get(cache_key)
except Exception: # pylint: disable=broad-except
if app.debug:
@@ -109,9 +114,9 @@ def etag_cache(max_age, check_perms=bool):
return response.make_conditional(request)
if cache:
- wrapper.uncached = f
- wrapper.cache_timeout = max_age
- wrapper.make_cache_key = cache._memoize_make_cache_key( # pylint: disable=protected-access
+ wrapper.uncached = f # type: ignore
+ wrapper.cache_timeout = max_age # type: ignore
+ wrapper.make_cache_key = cache._memoize_make_cache_key( # type: ignore # pylint: disable=protected-access
make_name=None, timeout=max_age
)
diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py
index d7ede85..a58635d 100644
--- a/superset/utils/dict_import_export.py
+++ b/superset/utils/dict_import_export.py
@@ -16,6 +16,9 @@
# under the License.
# pylint: disable=C,R,W
import logging
+from typing import Any, Dict, List, Optional
+
+from sqlalchemy.orm import Session
from superset.connectors.druid.models import DruidCluster
from superset.models.core import Database
@@ -25,7 +28,7 @@ DRUID_CLUSTERS_KEY = "druid_clusters"
logger = logging.getLogger(__name__)
-def export_schema_to_dict(back_references):
+def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
"""Exports the supported import/export schema to a dictionary"""
databases = [
Database.export_schema(recursive=True, include_parent_ref=back_references)
@@ -41,7 +44,9 @@ def export_schema_to_dict(back_references):
return data
-def export_to_dict(session, recursive, back_references, include_defaults):
+def export_to_dict(
+ session: Session, recursive: bool, back_references: bool, include_defaults: bool
+) -> Dict[str, Any]:
"""Exports databases and druid clusters to a dictionary"""
logger.info("Starting export")
dbs = session.query(Database)
@@ -72,8 +77,12 @@ def export_to_dict(session, recursive, back_references, include_defaults):
return data
-def import_from_dict(session, data, sync=[]):
+def import_from_dict(
+ session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
+) -> None:
"""Imports databases and druid clusters from dictionary"""
+ if not sync:
+ sync = []
if isinstance(data, dict):
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []):
diff --git a/superset/utils/feature_flag_manager.py b/superset/utils/feature_flag_manager.py
index 654607b..88f19c2 100644
--- a/superset/utils/feature_flag_manager.py
+++ b/superset/utils/feature_flag_manager.py
@@ -15,25 +15,33 @@
# specific language governing permissions and limitations
# under the License.
from copy import deepcopy
+from typing import Any, Dict
+
+from flask import Flask
class FeatureFlagManager:
def __init__(self) -> None:
super().__init__()
self._get_feature_flags_func = None
- self._feature_flags = None
+ self._feature_flags: Dict[str, Any] = {}
- def init_app(self, app):
- self._get_feature_flags_func = app.config.get("GET_FEATURE_FLAGS_FUNC")
- self._feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {}
- self._feature_flags.update(app.config.get("FEATURE_FLAGS") or {})
+ def init_app(self, app: Flask) -> None:
+ self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"]
+ self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"]
+ self._feature_flags.update(app.config["FEATURE_FLAGS"])
- def get_feature_flags(self):
+ def get_feature_flags(self) -> Dict[str, Any]:
if self._get_feature_flags_func:
return self._get_feature_flags_func(deepcopy(self._feature_flags))
return self._feature_flags
- def is_feature_enabled(self, feature) -> bool:
+ def is_feature_enabled(self, feature: str) -> bool:
"""Utility function for checking whether a feature is turned on"""
- return self.get_feature_flags().get(feature)
+ feature_flags = self.get_feature_flags()
+
+ if feature_flags and feature in feature_flags:
+ return feature_flags[feature]
+
+ return False
diff --git a/superset/utils/log.py b/superset/utils/log.py
index 5d8c52e..aafe3b8 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -21,19 +21,23 @@ import logging
import textwrap
from abc import ABC, abstractmethod
from datetime import datetime
-from typing import Any, cast, Type
+from typing import Any, Callable, cast, Optional, Type
from flask import current_app, g, request
+from superset.stats_logger import BaseStatsLogger
+
class AbstractEventLogger(ABC):
@abstractmethod
- def log(self, user_id, action, *args, **kwargs):
+ def log(
+ self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
+ ) -> None:
pass
- def log_this(self, f):
+ def log_this(self, f: Callable) -> Callable:
@functools.wraps(f)
- def wrapper(*args, **kwargs):
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
user_id = None
if g.user:
user_id = g.user.get_id()
@@ -49,7 +53,12 @@ class AbstractEventLogger(ABC):
try:
slice_id = int(
- slice_id or json.loads(form_data.get("form_data")).get("slice_id")
+ slice_id
+ or json.loads(
+ form_data.get("form_data") # type: ignore
+ ).get(
+ "slice_id"
+ )
)
except (ValueError, TypeError):
slice_id = 0
@@ -62,7 +71,7 @@ class AbstractEventLogger(ABC):
# bulk insert
try:
explode_by = form_data.get("explode")
- records = json.loads(form_data.get(explode_by))
+ records = json.loads(form_data.get(explode_by)) # type: ignore
except Exception: # pylint: disable=broad-except
records = [form_data]
@@ -82,11 +91,11 @@ class AbstractEventLogger(ABC):
return wrapper
@property
- def stats_logger(self):
+ def stats_logger(self) -> BaseStatsLogger:
return current_app.config["STATS_LOGGER"]
-def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger:
+def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
"""
This function implements the deprecation of assignment
of class objects to EVENT_LOGGER configuration, and validates
@@ -130,7 +139,9 @@ def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger:
class DBEventLogger(AbstractEventLogger):
- def log(self, user_id, action, *args, **kwargs): # pylint: disable=too-many-locals
+ def log( # pylint: disable=too-many-locals
+ self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
+ ) -> None:
from superset.models.core import Log
records = kwargs.get("records", list())
@@ -141,6 +152,7 @@ class DBEventLogger(AbstractEventLogger):
logs = list()
for record in records:
+ json_string: Optional[str]
try:
json_string = json.dumps(record)
except Exception: # pylint: disable=broad-except
diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index dabebed..39a4278 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -73,8 +73,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
def validate_column_args(*argnames: str) -> Callable:
- def wrapper(func):
- def wrapped(df, **options):
+ def wrapper(func: Callable) -> Callable:
+ def wrapped(df: DataFrame, **options: Any) -> Any:
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
@@ -159,7 +159,7 @@ def pivot( # pylint: disable=too-many-arguments
metric_fill_value: Optional[Any] = None,
column_fill_value: Optional[str] = None,
drop_missing_columns: Optional[bool] = True,
- combine_value_with_metric=False,
+ combine_value_with_metric: bool = False,
marginal_distributions: Optional[bool] = None,
marginal_distribution_name: Optional[str] = None,
) -> DataFrame:
diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py
index 18283e7..e07d2a2 100644
--- a/superset/utils/screenshots.py
+++ b/superset/utils/screenshots.py
@@ -18,7 +18,7 @@ import logging
import time
import urllib.parse
from io import BytesIO
-from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from flask import current_app, request, Response, session, url_for
from flask_login import login_user
@@ -91,7 +91,7 @@ def headless_url(path: str) -> str:
return urllib.parse.urljoin(current_app.config.get("WEBDRIVER_BASEURL", ""), path)
-def get_url_path(view: str, **kwargs) -> str:
+def get_url_path(view: str, **kwargs: Any) -> str:
with current_app.test_request_context():
return headless_url(url_for(view, **kwargs))
@@ -135,7 +135,7 @@ class AuthWebDriverProxy:
return self._auth_func(driver, user)
@staticmethod
- def destroy(driver: WebDriver, tries=2):
+ def destroy(driver: WebDriver, tries: int = 2) -> None:
"""Destroy a driver"""
# This is some very flaky code in selenium. Hence the retries
# and catch-all exceptions
diff --git a/superset/utils/url_map_converters.py b/superset/utils/url_map_converters.py
index 6697b7d..463dfcd 100644
--- a/superset/utils/url_map_converters.py
+++ b/superset/utils/url_map_converters.py
@@ -14,22 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from werkzeug.routing import BaseConverter
+from typing import Any, List
+
+from werkzeug.routing import BaseConverter, Map
from superset.models.tags import ObjectTypes
class RegexConverter(BaseConverter):
- def __init__(self, url_map, *items):
- super(RegexConverter, self).__init__(url_map)
+ def __init__(self, url_map: Map, *items: List[str]) -> None:
+ super(RegexConverter, self).__init__(url_map) # type: ignore
self.regex = items[0]
class ObjectTypeConverter(BaseConverter):
"""Validate that object_type is indeed an object type."""
- def to_python(self, value):
+ def to_python(self, value: str) -> Any:
return ObjectTypes[value]
- def to_url(self, value):
+ def to_url(self, value: Any) -> str:
return value.name
diff --git a/superset/views/core.py b/superset/views/core.py
index 716c11a..60454d9 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -2164,7 +2164,7 @@ class Superset(BaseSupersetView):
return json_error_response(str(ex))
spec = mydb.db_engine_spec
- query_cost_formatters = get_feature_flags().get(
+ query_cost_formatters: Dict[str, Any] = get_feature_flags().get(
"QUERY_COST_FORMATTERS_BY_ENGINE", {}
)
query_cost_formatter = query_cost_formatters.get(
diff --git a/tests/utils_tests.py b/tests/utils_tests.py
index 4a4b640..02d4a88 100644
--- a/tests/utils_tests.py
+++ b/tests/utils_tests.py
@@ -38,7 +38,6 @@ from superset.utils.core import (
base_json_conv,
convert_legacy_filters_into_adhoc,
create_ssl_cert_file,
- datetime_f,
format_timedelta,
get_iterable,
get_email_address_list,
@@ -560,17 +559,6 @@ class UtilsTestCase(SupersetTestCase):
url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"]
)
- def test_datetime_f(self):
- self.assertEqual(
- datetime_f(datetime(1990, 9, 21, 19, 11, 19, 626096)),
- "<nobr>1990-09-21T19:11:19.626096</nobr>",
- )
- self.assertEqual(len(datetime_f(datetime.now())), 28)
- self.assertEqual(datetime_f(None), "<nobr>None</nobr>")
- iso = datetime.now().isoformat()[:10].split("-")
- [a, b, c] = [int(v) for v in iso]
- self.assertEqual(datetime_f(datetime(a, b, c)), "<nobr>00:00:00</nobr>")
-
def test_format_timedelta(self):
self.assertEqual(format_timedelta(timedelta(0)), "0:00:00")
self.assertEqual(format_timedelta(timedelta(days=1)), "1 day, 0:00:00")