Posted to commits@superset.apache.org by jo...@apache.org on 2020/06/03 22:26:32 UTC
[incubator-superset] branch master updated: style(mypy): Enforcing typing for superset (#9943)
This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 244677c style(mypy): Enforcing typing for superset (#9943)
244677c is described below
commit 244677cf5e0ecb7c767455e96655af6c18cc58bc
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Wed Jun 3 15:26:12 2020 -0700
style(mypy): Enforcing typing for superset (#9943)
Co-authored-by: John Bodley <jo...@airbnb.com>
---
setup.cfg | 2 +-
superset/app.py | 66 +++++----
superset/cli.py | 56 ++++----
superset/config.py | 21 +--
superset/exceptions.py | 2 +-
superset/extensions.py | 45 +++---
superset/forms.py | 16 +--
superset/jinja_context.py | 26 ++--
superset/sql_lab.py | 76 +++++-----
superset/sql_parse.py | 6 +-
superset/stats_logger.py | 27 ++--
superset/typing.py | 1 +
superset/viz.py | 357 ++++++++++++++++++++++++++--------------------
superset/viz_sip38.py | 3 +-
tests/viz_tests.py | 2 +-
15 files changed, 393 insertions(+), 313 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index 1115de9..fc94a24 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
-[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
+[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset [...]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
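For context, the three flags now enforced for the packages listed above tighten mypy in complementary ways. A minimal sketch of what each rejects (hypothetical module, not part of this commit):

    def untyped(x):            # disallow_untyped_defs: error, no annotations
        return x + 1           # check_untyped_defs: body is checked regardless

    def typed(x: int) -> int:
        return untyped(x)      # disallow_untyped_calls: error, callee is untyped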
diff --git a/superset/app.py b/superset/app.py
index 98f1459..18165ed 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -17,6 +17,7 @@
import logging
import os
+from typing import Any, Callable, Dict
import wtforms_json
from flask import Flask, redirect
@@ -41,13 +42,14 @@ from superset.extensions import (
talisman,
)
from superset.security import SupersetSecurityManager
+from superset.typing import FlaskResponse
from superset.utils.core import pessimistic_connection_handling
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
logger = logging.getLogger(__name__)
-def create_app():
+def create_app() -> Flask:
app = Flask(__name__)
try:
@@ -68,7 +70,7 @@ def create_app():
class SupersetIndexView(IndexView):
@expose("/")
- def index(self):
+ def index(self) -> FlaskResponse:
return redirect("/superset/welcome")
@@ -109,8 +111,8 @@ class SupersetAppInitializer:
abstract = True
# Grab each call into the task and set up an app context
- def __call__(self, *args, **kwargs):
- with flask_app.app_context():
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
+ with flask_app.app_context(): # type: ignore
return task_base.__call__(self, *args, **kwargs)
celery_app.Task = AppContextTask
@@ -454,51 +456,41 @@ class SupersetAppInitializer:
order to fully init the app
"""
self.pre_init()
-
self.setup_db()
-
self.configure_celery()
-
self.setup_event_logger()
-
self.setup_bundle_manifest()
-
self.register_blueprints()
-
self.configure_wtf()
-
self.configure_logging()
-
self.configure_middlewares()
-
self.configure_cache()
-
self.configure_jinja_context()
- with self.flask_app.app_context():
+ with self.flask_app.app_context(): # type: ignore
self.init_app_in_ctx()
self.post_init()
- def setup_event_logger(self):
+ def setup_event_logger(self) -> None:
_event_logger["event_logger"] = get_event_logger_from_cfg_value(
self.flask_app.config.get("EVENT_LOGGER", DBEventLogger())
)
- def configure_data_sources(self):
+ def configure_data_sources(self) -> None:
# Registering sources
module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
ConnectorRegistry.register_sources(module_datasource_map)
- def configure_cache(self):
+ def configure_cache(self) -> None:
cache_manager.init_app(self.flask_app)
results_backend_manager.init_app(self.flask_app)
- def configure_feature_flags(self):
+ def configure_feature_flags(self) -> None:
feature_flag_manager.init_app(self.flask_app)
- def configure_fab(self):
+ def configure_fab(self) -> None:
if self.config["SILENCE_FAB"]:
logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)
@@ -516,7 +508,7 @@ class SupersetAppInitializer:
appbuilder.update_perms = False
appbuilder.init_app(self.flask_app, db.session)
- def configure_url_map_converters(self):
+ def configure_url_map_converters(self) -> None:
#
# Doing local imports here as model importing causes a reference to
# app.config to be invoked and we need the current_app to have been setup
@@ -527,10 +519,10 @@ class SupersetAppInitializer:
self.flask_app.url_map.converters["regex"] = RegexConverter
self.flask_app.url_map.converters["object_type"] = ObjectTypeConverter
- def configure_jinja_context(self):
+ def configure_jinja_context(self) -> None:
jinja_context_manager.init_app(self.flask_app)
- def configure_middlewares(self):
+ def configure_middlewares(self) -> None:
if self.config["ENABLE_CORS"]:
from flask_cors import CORS
@@ -539,24 +531,28 @@ class SupersetAppInitializer:
if self.config["ENABLE_PROXY_FIX"]:
from werkzeug.middleware.proxy_fix import ProxyFix
- self.flask_app.wsgi_app = ProxyFix(
+ self.flask_app.wsgi_app = ProxyFix( # type: ignore
self.flask_app.wsgi_app, **self.config["PROXY_FIX_CONFIG"]
)
if self.config["ENABLE_CHUNK_ENCODING"]:
class ChunkedEncodingFix: # pylint: disable=too-few-public-methods
- def __init__(self, app):
+ def __init__(self, app: Flask) -> None:
self.app = app
- def __call__(self, environ, start_response):
+ def __call__(
+ self, environ: Dict[str, Any], start_response: Callable
+ ) -> Any:
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end.
if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == "chunked":
environ["wsgi.input_terminated"] = True
return self.app(environ, start_response)
- self.flask_app.wsgi_app = ChunkedEncodingFix(self.flask_app.wsgi_app)
+ self.flask_app.wsgi_app = ChunkedEncodingFix( # type: ignore
+ self.flask_app.wsgi_app # type: ignore
+ )
if self.config["UPLOAD_FOLDER"]:
try:
@@ -565,7 +561,9 @@ class SupersetAppInitializer:
pass
for middleware in self.config["ADDITIONAL_MIDDLEWARE"]:
- self.flask_app.wsgi_app = middleware(self.flask_app.wsgi_app)
+ self.flask_app.wsgi_app = middleware( # type: ignore
+ self.flask_app.wsgi_app
+ )
# Flask-Compress
if self.config["ENABLE_FLASK_COMPRESS"]:
@@ -574,27 +572,27 @@ class SupersetAppInitializer:
if self.config["TALISMAN_ENABLED"]:
talisman.init_app(self.flask_app, **self.config["TALISMAN_CONFIG"])
- def configure_logging(self):
+ def configure_logging(self) -> None:
self.config["LOGGING_CONFIGURATOR"].configure_logging(
self.config, self.flask_app.debug
)
- def setup_db(self):
+ def setup_db(self) -> None:
db.init_app(self.flask_app)
- with self.flask_app.app_context():
+ with self.flask_app.app_context(): # type: ignore
pessimistic_connection_handling(db.engine)
migrate.init_app(self.flask_app, db=db, directory=APP_DIR + "/migrations")
- def configure_wtf(self):
+ def configure_wtf(self) -> None:
if self.config["WTF_CSRF_ENABLED"]:
csrf = CSRFProtect(self.flask_app)
csrf_exempt_list = self.config["WTF_CSRF_EXEMPT_LIST"]
for ex in csrf_exempt_list:
csrf.exempt(ex)
- def register_blueprints(self):
+ def register_blueprints(self) -> None:
for bp in self.config["BLUEPRINTS"]:
try:
logger.info(f"Registering blueprint: '{bp.name}'")
@@ -602,5 +600,5 @@ class SupersetAppInitializer:
except Exception: # pylint: disable=broad-except
logger.exception("blueprint registration failed")
- def setup_bundle_manifest(self):
+ def setup_bundle_manifest(self) -> None:
manifest_processor.init_app(self.flask_app)
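The middleware hunks above annotate the standard WSGI calling convention. A self-contained sketch of that pattern, assuming a pass-through middleware analogous to ChunkedEncodingFix (names here are illustrative, not Superset code):

    from typing import Any, Callable, Dict, Iterable

    class PassthroughMiddleware:
        def __init__(self, app: Callable) -> None:
            self.app = app

        def __call__(
            self, environ: Dict[str, Any], start_response: Callable
        ) -> Iterable[bytes]:
            # Inspect or augment the WSGI environ before delegating.
            environ.setdefault("wsgi.input_terminated", False)
            return self.app(environ, start_response)

The bare Callable for start_response mirrors the annotations in the commit; a stricter callable signature is possible but was not used here.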
diff --git a/superset/cli.py b/superset/cli.py
index 090e1fb..a136010 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -19,10 +19,11 @@ import logging
from datetime import datetime
from subprocess import Popen
from sys import stdout
-from typing import Type, Union
+from typing import Any, Dict, Type, Union
import click
import yaml
+from celery.utils.abstract import CallableTask
from colorama import Fore, Style
from flask import g
from flask.cli import FlaskGroup, with_appcontext
@@ -56,17 +57,17 @@ def normalize_token(token_name: str) -> str:
context_settings={"token_normalize_func": normalize_token},
)
@with_appcontext
-def superset():
+def superset() -> None:
"""This is a management script for the Superset application."""
@app.shell_context_processor
- def make_shell_context(): # pylint: disable=unused-variable
+ def make_shell_context() -> Dict[str, Any]: # pylint: disable=unused-variable
return dict(app=app, db=db)
@superset.command()
@with_appcontext
-def init():
+def init() -> None:
"""Inits the Superset application"""
appbuilder.add_permissions(update_perms=True)
security_manager.sync_role_definitions()
@@ -75,7 +76,7 @@ def init():
@superset.command()
@with_appcontext
@click.option("--verbose", "-v", is_flag=True, help="Show extra information")
-def version(verbose):
+def version(verbose: bool) -> None:
"""Prints the current version number"""
print(Fore.BLUE + "-=" * 15)
print(
@@ -90,7 +91,9 @@ def version(verbose):
print(Style.RESET_ALL)
-def load_examples_run(load_test_data, only_metadata=False, force=False):
+def load_examples_run(
+ load_test_data: bool, only_metadata: bool = False, force: bool = False
+) -> None:
if only_metadata:
print("Loading examples metadata")
else:
@@ -160,7 +163,9 @@ def load_examples_run(load_test_data, only_metadata=False, force=False):
@click.option(
"--force", "-f", is_flag=True, help="Force load data even if table already exists"
)
-def load_examples(load_test_data, only_metadata=False, force=False):
+def load_examples(
+ load_test_data: bool, only_metadata: bool = False, force: bool = False
+) -> None:
"""Loads a set of Slices and Dashboards and a supporting dataset """
load_examples_run(load_test_data, only_metadata, force)
@@ -169,7 +174,7 @@ def load_examples(load_test_data, only_metadata=False, force=False):
@superset.command()
@click.option("--database_name", "-d", help="Database name to change")
@click.option("--uri", "-u", help="Database URI to change")
-def set_database_uri(database_name, uri):
+def set_database_uri(database_name: str, uri: str) -> None:
"""Updates a database connection URI """
utils.get_or_create_db(database_name, uri)
@@ -189,7 +194,7 @@ def set_database_uri(database_name, uri):
default=False,
help="Specify using 'merge' property during operation. " "Default value is False.",
)
-def refresh_druid(datasource, merge):
+def refresh_druid(datasource: str, merge: bool) -> None:
"""Refresh druid datasources"""
session = db.session()
from superset.connectors.druid.models import DruidCluster
@@ -226,7 +231,7 @@ def refresh_druid(datasource, merge):
default=None,
help="Specify the user name to assign dashboards to",
)
-def import_dashboards(path, recursive, username):
+def import_dashboards(path: str, recursive: bool, username: str) -> None:
"""Import dashboards from JSON"""
from superset.utils import dashboard_import_export
@@ -258,7 +263,7 @@ def import_dashboards(path, recursive, username):
@click.option(
"--print_stdout", "-p", is_flag=True, default=False, help="Print JSON to stdout"
)
-def export_dashboards(print_stdout, dashboard_file):
+def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
"""Export dashboards to JSON"""
from superset.utils import dashboard_import_export
@@ -295,7 +300,7 @@ def export_dashboards(print_stdout, dashboard_file):
default=False,
help="recursively search the path for yaml files",
)
-def import_datasources(path, sync, recursive):
+def import_datasources(path: str, sync: str, recursive: bool) -> None:
"""Import datasources from YAML"""
from superset.utils import dict_import_export
@@ -345,8 +350,11 @@ def import_datasources(path, sync, recursive):
help="Include fields containing defaults",
)
def export_datasources(
- print_stdout, datasource_file, back_references, include_defaults
-):
+ print_stdout: bool,
+ datasource_file: str,
+ back_references: bool,
+ include_defaults: bool,
+) -> None:
"""Export datasources to YAML"""
from superset.utils import dict_import_export
@@ -373,7 +381,7 @@ def export_datasources(
default=False,
help="Include parent back references",
)
-def export_datasource_schema(back_references):
+def export_datasource_schema(back_references: bool) -> None:
"""Export datasource YAML schema to stdout"""
from superset.utils import dict_import_export
@@ -383,7 +391,7 @@ def export_datasource_schema(back_references):
@superset.command()
@with_appcontext
-def update_datasources_cache():
+def update_datasources_cache() -> None:
"""Refresh sqllab datasources cache"""
from superset.models.core import Database
@@ -406,7 +414,7 @@ def update_datasources_cache():
@click.option(
"--workers", "-w", type=int, help="Number of celery server workers to fire up"
)
-def worker(workers):
+def worker(workers: int) -> None:
"""Starts a Superset worker for async SQL query execution."""
logger.info(
"The 'superset worker' command is deprecated. Please use the 'celery "
@@ -431,7 +439,7 @@ def worker(workers):
@click.option(
"-a", "--address", default="localhost", help="Address on which to run the service"
)
-def flower(port, address):
+def flower(port: int, address: str) -> None:
"""Runs a Celery Flower web server
Celery Flower is a UI to monitor the Celery operation on a given
@@ -487,7 +495,7 @@ def compute_thumbnails(
charts_only: bool,
force: bool,
model_id: int,
-):
+) -> None:
"""Compute thumbnails"""
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@@ -500,8 +508,8 @@ def compute_thumbnails(
friendly_type: str,
model_cls: Union[Type[Dashboard], Type[Slice]],
model_id: int,
- compute_func,
- ):
+ compute_func: CallableTask,
+ ) -> None:
query = db.session.query(model_cls)
if model_id:
query = query.filter(model_cls.id.in_(model_id))
@@ -528,7 +536,7 @@ def compute_thumbnails(
@superset.command()
@with_appcontext
-def load_test_users():
+def load_test_users() -> None:
"""
Loads admin, alpha, and gamma user for testing purposes
@@ -538,7 +546,7 @@ def load_test_users():
load_test_users_run()
-def load_test_users_run():
+def load_test_users_run() -> None:
"""
Loads admin, alpha, and gamma user for testing purposes
@@ -583,7 +591,7 @@ def load_test_users_run():
@superset.command()
@with_appcontext
-def sync_tags():
+def sync_tags() -> None:
"""Rebuilds special tags (owner, type, favorited by)."""
# pylint: disable=no-member
metadata = Model.metadata
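Note that the annotations added to these click commands are for mypy only; click still derives its runtime parsing from the @click.option decorators. A minimal sketch of the pattern (hypothetical command, not from this commit):

    import click

    @click.command()
    @click.option("--verbose", "-v", is_flag=True, help="Show extra information")
    def version(verbose: bool) -> None:
        """Print a version banner (illustrative)."""
        click.echo("verbose" if verbose else "quiet")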
diff --git a/superset/config.py b/superset/config.py
index 738c251..35dbbf8 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -28,8 +28,9 @@ import os
import sys
from collections import OrderedDict
from datetime import date
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
+from cachelib.base import BaseCache
from celery.schedules import crontab
from dateutil import tz
from flask_appbuilder.security.manager import AUTH_DB
@@ -78,7 +79,7 @@ PACKAGE_JSON_FILE = os.path.join(BASE_DIR, "static", "assets", "package.json")
FAVICONS = [{"href": "/static/assets/images/favicon.png"}]
-def _try_json_readversion(filepath):
+def _try_json_readversion(filepath: str) -> Optional[str]:
try:
with open(filepath, "r") as f:
return json.load(f).get("version")
@@ -86,7 +87,9 @@ def _try_json_readversion(filepath):
return None
-def _try_json_readsha(filepath, length): # pylint: disable=unused-argument
+def _try_json_readsha( # pylint: disable=unused-argument
+ filepath: str, length: int
+) -> Optional[str]:
try:
with open(filepath, "r") as f:
return json.load(f).get("GIT_SHA")[:length]
@@ -453,6 +456,7 @@ BACKUP_COUNT = 30
# user=None,
# client=None,
# security_manager=None,
+# log_params=None,
# ):
# pass
QUERY_LOGGER = None
@@ -578,10 +582,9 @@ SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
Callable[["Database", "models.User", str, str], str]
] = None
-# An instantiated derivative of cachelib.base.BaseCache
-# if enabled, it can be used to store the results of long-running queries
+# If enabled, it can be used to store the results of long-running queries
# in SQL Lab by using the "Run Async" button/feature
-RESULTS_BACKEND = None
+RESULTS_BACKEND: Optional[BaseCache] = None
# Use PyArrow and MessagePack for async query results serialization,
# rather than JSON. This feature requires additional testing from the
@@ -604,7 +607,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC: Callable[
# The namespace within hive where the tables created from
# uploading CSVs will be stored.
-UPLOADED_CSV_HIVE_NAMESPACE = None
+UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None
# Function that computes the allowed schemas for the CSV uploads.
# Allowed schemas will be a union of schemas_allowed_for_csv_upload
@@ -614,7 +617,7 @@ UPLOADED_CSV_HIVE_NAMESPACE = None
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
["Database", "models.User"], List[str]
] = lambda database, user: [
- UPLOADED_CSV_HIVE_NAMESPACE # type: ignore
+ UPLOADED_CSV_HIVE_NAMESPACE
] if UPLOADED_CSV_HIVE_NAMESPACE else []
# A dictionary of items that gets merged into the Jinja context for
@@ -628,7 +631,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {}
# dictionary, which means the existing keys get overwritten by the content of this
# dictionary. The customized addons don't necessarily need to use jinjia templating
# language. This allows you to define custom logic to process macro template.
-CUSTOM_TEMPLATE_PROCESSORS = {} # type: Dict[str, BaseTemplateProcessor]
+CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {}
# Roles that are controlled by the API / Superset and should not be changes
# by humans.
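The CUSTOM_TEMPLATE_PROCESSORS change is worth a note: the new annotation uses Type[BaseTemplateProcessor] because the dictionary stores classes to be instantiated later, not instances. A minimal sketch with a stand-in base class (illustrative names only):

    from typing import Dict, Type

    class Processor:  # stand-in for BaseTemplateProcessor
        pass

    class PrestoProcessor(Processor):
        pass

    registry: Dict[str, Type[Processor]] = {"presto": PrestoProcessor}
    instance: Processor = registry["presto"]()  # instantiated on demand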
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 59ea042..51bd85f 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -32,7 +32,7 @@ class SupersetException(Exception):
super().__init__(self.message)
@property
- def exception(self):
+ def exception(self) -> Optional[Exception]:
return self._exception
diff --git a/superset/extensions.py b/superset/extensions.py
index c501eeb..f321046 100644
--- a/superset/extensions.py
+++ b/superset/extensions.py
@@ -20,10 +20,12 @@ import random
import time
import uuid
from datetime import datetime, timedelta
-from typing import Dict, TYPE_CHECKING # pylint: disable=unused-import
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
import celery
+from cachelib.base import BaseCache
from dateutil.relativedelta import relativedelta
+from flask import Flask
from flask_appbuilder import AppBuilder, SQLA
from flask_migrate import Migrate
from flask_talisman import Talisman
@@ -32,7 +34,6 @@ from werkzeug.local import LocalProxy
from superset.utils.cache_manager import CacheManager
from superset.utils.feature_flag_manager import FeatureFlagManager
-# Avoid circular import
if TYPE_CHECKING:
from superset.jinja_context import ( # pylint: disable=unused-import
BaseTemplateProcessor,
@@ -49,18 +50,18 @@ class JinjaContextManager:
"timedelta": timedelta,
"uuid": uuid,
}
- self._template_processors = {} # type: Dict[str, BaseTemplateProcessor]
+ self._template_processors: Dict[str, Type["BaseTemplateProcessor"]] = {}
- def init_app(self, app):
+ def init_app(self, app: Flask) -> None:
self._base_context.update(app.config["JINJA_CONTEXT_ADDONS"])
self._template_processors.update(app.config["CUSTOM_TEMPLATE_PROCESSORS"])
@property
- def base_context(self):
+ def base_context(self) -> Dict[str, Any]:
return self._base_context
@property
- def template_processors(self):
+ def template_processors(self) -> Dict[str, Type["BaseTemplateProcessor"]]:
return self._template_processors
@@ -69,35 +70,35 @@ class ResultsBackendManager:
self._results_backend = None
self._use_msgpack = False
- def init_app(self, app):
- self._results_backend = app.config.get("RESULTS_BACKEND")
- self._use_msgpack = app.config.get("RESULTS_BACKEND_USE_MSGPACK")
+ def init_app(self, app: Flask) -> None:
+ self._results_backend = app.config["RESULTS_BACKEND"]
+ self._use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"]
@property
- def results_backend(self):
+ def results_backend(self) -> Optional[BaseCache]:
return self._results_backend
@property
- def should_use_msgpack(self):
+ def should_use_msgpack(self) -> bool:
return self._use_msgpack
class UIManifestProcessor:
def __init__(self, app_dir: str) -> None:
- self.app = None
- self.manifest: dict = {}
+ self.app: Optional[Flask] = None
+ self.manifest: Dict[str, Dict[str, List[str]]] = {}
self.manifest_file = f"{app_dir}/static/assets/manifest.json"
- def init_app(self, app):
+ def init_app(self, app: Flask) -> None:
self.app = app
# Preload the cache
self.parse_manifest_json()
@app.context_processor
- def get_manifest(): # pylint: disable=unused-variable
+ def get_manifest() -> Dict[str, Callable]: # pylint: disable=unused-variable
loaded_chunks = set()
- def get_files(bundle, asset_type="js"):
+ def get_files(bundle: str, asset_type: str = "js") -> List[str]:
files = self.get_manifest_files(bundle, asset_type)
filtered_files = [f for f in files if f not in loaded_chunks]
for f in filtered_files:
@@ -109,18 +110,18 @@ class UIManifestProcessor:
css_manifest=lambda bundle: get_files(bundle, "css"),
)
- def parse_manifest_json(self):
+ def parse_manifest_json(self) -> None:
try:
with open(self.manifest_file, "r") as f:
- # the manifest includes non-entry files
- # we only need entries in templates
+ # the manifest includes non-entry files we only need entries in
+ # templates
full_manifest = json.load(f)
self.manifest = full_manifest.get("entrypoints", {})
except Exception: # pylint: disable=broad-except
pass
- def get_manifest_files(self, bundle, asset_type):
- if self.app.debug:
+ def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]:
+ if self.app and self.app.debug:
self.parse_manifest_json()
return self.manifest.get(bundle, {}).get(asset_type, [])
@@ -133,7 +134,7 @@ db = SQLA()
_event_logger: dict = {}
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager()
-jinja_context_manager = JinjaContextManager() # type: JinjaContextManager
+jinja_context_manager = JinjaContextManager()
manifest_processor = UIManifestProcessor(APP_DIR)
migrate = Migrate()
results_backend_manager = ResultsBackendManager()
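The get_manifest_files change above leans on mypy's Optional narrowing: once self.app is declared Optional[Flask], a truthiness check is required before touching .debug. A standalone sketch of that narrowing (illustrative, not Superset code):

    from typing import Optional

    class Holder:
        def __init__(self) -> None:
            self.app: Optional[str] = None

        def use(self) -> str:
            if self.app:  # narrows Optional[str] to str in this branch
                return self.app.upper()
            return ""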
diff --git a/superset/forms.py b/superset/forms.py
index 175903a..4ba3ca2 100644
--- a/superset/forms.py
+++ b/superset/forms.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Contains the logic to create cohesive forms on the explore view"""
-from typing import List # pylint: disable=unused-import
+from typing import Any, List, Optional
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from wtforms import Field
@@ -25,24 +25,24 @@ class CommaSeparatedListField(Field):
widget = BS3TextFieldWidget()
data: List[str] = []
- def _value(self):
+ def _value(self) -> str:
if self.data:
- return u", ".join(self.data)
+ return ", ".join(self.data)
- return u""
+ return ""
- def process_formdata(self, valuelist):
+ def process_formdata(self, valuelist: List[str]) -> None:
if valuelist:
self.data = [x.strip() for x in valuelist[0].split(",")]
else:
self.data = []
-def filter_not_empty_values(value):
+def filter_not_empty_values(values: Optional[List[Any]]) -> Optional[List[Any]]:
"""Returns a list of non empty values or None"""
- if not value:
+ if not values:
return None
- data = [x for x in value if x]
+ data = [value for value in values if value]
if not data:
return None
return data
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index e1a10cd..95ee723 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -17,7 +17,7 @@
"""Defines the templating context for SQL Lab"""
import inspect
import re
-from typing import Any, List, Optional, Tuple, TYPE_CHECKING
+from typing import Any, cast, List, Optional, Tuple, TYPE_CHECKING
from flask import g, request
from jinja2.sandbox import SandboxedEnvironment
@@ -207,7 +207,7 @@ class BaseTemplateProcessor: # pylint: disable=too-few-public-methods
def __init__(
self,
- database: Optional["Database"] = None,
+ database: "Database",
query: Optional["Query"] = None,
table: Optional["SqlaTable"] = None,
extra_cache_keys: Optional[List[Any]] = None,
@@ -266,7 +266,7 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
schema, table_name = table_name.split(".")
return table_name, schema
- def first_latest_partition(self, table_name: str) -> str:
+ def first_latest_partition(self, table_name: str) -> Optional[str]:
"""
Gets the first value in the array of all latest partitions
@@ -275,9 +275,10 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
:raises IndexError: If no partition exists
"""
- return self.latest_partitions(table_name)[0]
+ latest_partitions = self.latest_partitions(table_name)
+ return latest_partitions[0] if latest_partitions else None
- def latest_partitions(self, table_name: str) -> List[str]:
+ def latest_partitions(self, table_name: str) -> Optional[List[str]]:
"""
Gets the array of all latest partitions
@@ -285,16 +286,21 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
:return: the latest partition array
"""
+ from superset.db_engine_specs.presto import PrestoEngineSpec
+
table_name, schema = self._schema_table(table_name, self.schema)
- assert self.database
- return self.database.db_engine_spec.latest_partition( # type: ignore
+ return cast(PrestoEngineSpec, self.database.db_engine_spec).latest_partition(
table_name, schema, self.database
)[1]
- def latest_sub_partition(self, table_name, **kwargs):
+ def latest_sub_partition(self, table_name: str, **kwargs: Any) -> Any:
table_name, schema = self._schema_table(table_name, self.schema)
- assert self.database
- return self.database.db_engine_spec.latest_sub_partition(
+
+ from superset.db_engine_specs.presto import PrestoEngineSpec
+
+ return cast(
+ PrestoEngineSpec, self.database.db_engine_spec
+ ).latest_sub_partition(
table_name=table_name, schema=schema, database=self.database, **kwargs
)
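The cast() calls introduced above replace assert-plus-type-ignore: typing.cast is a no-op at runtime and only tells the checker to treat a value as a narrower type. A minimal sketch (illustrative class names, not Superset code):

    from typing import cast

    class Base:
        pass

    class Derived(Base):
        def only_on_derived(self) -> str:
            return "ok"

    def use(obj: Base) -> str:
        # The caller guarantees obj is really a Derived; cast does no checking.
        return cast(Derived, obj).only_on_derived()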
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index ab952db..3ba1e3a 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -19,7 +19,7 @@ import uuid
from contextlib import closing
from datetime import datetime
from sys import getsizeof
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
import backoff
import msgpack
@@ -27,9 +27,10 @@ import pyarrow as pa
import simplejson as json
import sqlalchemy
from celery.exceptions import SoftTimeLimitExceeded
+from celery.task.base import Task
from contextlib2 import contextmanager
from flask_babel import lazy_gettext as _
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import NullPool
from superset import (
@@ -77,7 +78,9 @@ class SqlLabTimeoutException(SqlLabException):
pass
-def handle_query_error(msg, query, session, payload=None):
+def handle_query_error(
+ msg: str, query: Query, session: Session, payload: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
"""Local method handling error while processing the SQL"""
payload = payload or {}
troubleshooting_link = config["TROUBLESHOOTING_LINK"]
@@ -91,14 +94,14 @@ def handle_query_error(msg, query, session, payload=None):
return payload
-def get_query_backoff_handler(details):
+def get_query_backoff_handler(details: Dict[Any, Any]) -> None:
query_id = details["kwargs"]["query_id"]
logger.error(f"Query with id `{query_id}` could not be retrieved")
stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1))
logger.error(f"Query {query_id}: Sleeping for a sec before retrying...")
-def get_query_giveup_handler(_):
+def get_query_giveup_handler(_: Any) -> None:
stats_logger.incr("error_failed_at_getting_orm_query")
@@ -110,7 +113,7 @@ def get_query_giveup_handler(_):
on_giveup=get_query_giveup_handler,
max_tries=5,
)
-def get_query(query_id, session):
+def get_query(query_id: int, session: Session) -> Query:
"""attempts to get the query and retry if it cannot"""
try:
return session.query(Query).filter_by(id=query_id).one()
@@ -119,7 +122,7 @@ def get_query(query_id, session):
@contextmanager
-def session_scope(nullpool):
+def session_scope(nullpool: bool) -> Iterator[Session]:
"""Provide a transactional scope around a series of operations."""
database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
if "sqlite" in database_uri:
@@ -154,16 +157,16 @@ def session_scope(nullpool):
soft_time_limit=SQLLAB_TIMEOUT,
)
def get_sql_results( # pylint: disable=too-many-arguments
- ctask,
- query_id,
- rendered_query,
- return_results=True,
- store_results=False,
- user_name=None,
- start_time=None,
- expand_data=False,
- log_params=None,
-):
+ ctask: Task,
+ query_id: int,
+ rendered_query: str,
+ return_results: bool = True,
+ store_results: bool = False,
+ user_name: Optional[str] = None,
+ start_time: Optional[float] = None,
+ expand_data: bool = False,
+ log_params: Optional[Dict[str, Any]] = None,
+) -> Optional[Dict[str, Any]]:
"""Executes the sql query returns the results."""
with session_scope(not ctask.request.called_directly) as session:
@@ -188,7 +191,14 @@ def get_sql_results( # pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments
-def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_params):
+def execute_sql_statement(
+ sql_statement: str,
+ query: Query,
+ user_name: Optional[str],
+ session: Session,
+ cursor: Any,
+ log_params: Optional[Dict[str, Any]],
+) -> SupersetResultSet:
"""Executes a single SQL statement"""
database = query.database
db_engine_spec = database.db_engine_spec
@@ -275,7 +285,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_
def _serialize_payload(
- payload: dict, use_msgpack: Optional[bool] = False
+ payload: Dict[Any, Any], use_msgpack: Optional[bool] = False
) -> Union[bytes, str]:
logger.debug(f"Serializing to msgpack: {use_msgpack}")
if use_msgpack:
@@ -321,24 +331,24 @@ def _serialize_and_expand_data(
return (data, selected_columns, all_columns, expanded_columns)
-def execute_sql_statements(
- query_id,
- rendered_query,
- return_results=True,
- store_results=False,
- user_name=None,
- session=None,
- start_time=None,
- expand_data=False,
- log_params=None,
-): # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
+def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
+ query_id: int,
+ rendered_query: str,
+ return_results: bool,
+ store_results: bool,
+ user_name: Optional[str],
+ session: Session,
+ start_time: Optional[float],
+ expand_data: bool,
+ log_params: Optional[Dict[str, Any]],
+) -> Optional[Dict[str, Any]]:
"""Executes the sql query returns the results."""
if store_results and start_time:
# only asynchronous queries
stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)
query = get_query(query_id, session)
- payload = dict(query_id=query_id)
+ payload: Dict[str, Any] = dict(query_id=query_id)
database = query.database
db_engine_spec = database.db_engine_spec
db_engine_spec.patch()
@@ -406,7 +416,7 @@ def execute_sql_statements(
)
query.end_time = now_as_float()
- use_arrow_data = store_results and results_backend_use_msgpack
+ use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
result_set, db_engine_spec, use_arrow_data, expand_data
)
@@ -432,7 +442,7 @@ def execute_sql_statements(
"sqllab.query.results_backend_write_serialization", stats_logger
):
serialized_payload = _serialize_payload(
- payload, results_backend_use_msgpack
+ payload, cast(bool, results_backend_use_msgpack)
)
cache_timeout = database.cache_timeout
if cache_timeout is None:
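The Iterator[Session] annotation on session_scope follows the standard convention for @contextmanager generators: the annotation describes the generator itself, while callers still use the result as a context manager. A sketch with a generic resource (illustrative, not Superset code):

    from contextlib import contextmanager
    from typing import Iterator

    @contextmanager
    def resource_scope() -> Iterator[str]:
        resource = "opened"
        try:
            yield resource  # the yielded type is the Iterator item type
        finally:
            pass  # cleanup would go here

    with resource_scope() as r:
        print(r)  # -> opened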
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index be9cf10..3e50386 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -158,7 +158,7 @@ class ParsedQuery:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))
- def _process_tokenlist(self, token_list: TokenList):
+ def _process_tokenlist(self, token_list: TokenList) -> None:
"""
Add table names to table set
@@ -204,7 +204,9 @@ class ParsedQuery:
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
return exec_sql
- def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches
+ def _extract_from_token( # pylint: disable=too-many-branches
+ self, token: Token
+ ) -> None:
"""
Populate self._tables from token
diff --git a/superset/stats_logger.py b/superset/stats_logger.py
index 37fe3d3..75cfd8a 100644
--- a/superset/stats_logger.py
+++ b/superset/stats_logger.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
+from typing import Optional
from colorama import Fore, Style
@@ -40,7 +41,7 @@ class BaseStatsLogger:
"""Decrement a counter"""
raise NotImplementedError()
- def timing(self, key, value: float) -> None:
+ def timing(self, key: str, value: float) -> None:
raise NotImplementedError()
def gauge(self, key: str) -> None:
@@ -49,18 +50,18 @@ class BaseStatsLogger:
class DummyStatsLogger(BaseStatsLogger):
- def incr(self, key):
+ def incr(self, key: str) -> None:
logger.debug(Fore.CYAN + "[stats_logger] (incr) " + key + Style.RESET_ALL)
- def decr(self, key):
+ def decr(self, key: str) -> None:
logger.debug((Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL))
- def timing(self, key, value):
+ def timing(self, key: str, value: float) -> None:
logger.debug(
(Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL)
)
- def gauge(self, key):
+ def gauge(self, key: str) -> None:
logger.debug(
(Fore.CYAN + "[stats_logger] (gauge) " + f"{key}" + Style.RESET_ALL)
)
@@ -71,8 +72,12 @@ try:
class StatsdStatsLogger(BaseStatsLogger):
def __init__( # pylint: disable=super-init-not-called
- self, host="localhost", port=8125, prefix="superset", statsd_client=None
- ):
+ self,
+ host: str = "localhost",
+ port: int = 8125,
+ prefix: str = "superset",
+ statsd_client: Optional[StatsClient] = None,
+ ) -> None:
"""
Initializes from either params or a supplied, pre-constructed statsd client.
@@ -84,16 +89,16 @@ try:
else:
self.client = StatsClient(host=host, port=port, prefix=prefix)
- def incr(self, key):
+ def incr(self, key: str) -> None:
self.client.incr(key)
- def decr(self, key):
+ def decr(self, key: str) -> None:
self.client.decr(key)
- def timing(self, key, value):
+ def timing(self, key: str, value: float) -> None:
self.client.timing(key, value)
- def gauge(self, key):
+ def gauge(self, key: str) -> None:
# pylint: disable=no-value-for-parameter
self.client.gauge(key)
diff --git a/superset/typing.py b/superset/typing.py
index 09a3393..e238000 100644
--- a/superset/typing.py
+++ b/superset/typing.py
@@ -33,6 +33,7 @@ Granularity = Union[str, Dict[str, Union[str, float]]]
Metric = Union[Dict[str, str], str]
QueryObjectDict = Dict[str, Any]
VizData = Optional[Union[List[Any], Dict[Any, Any]]]
+VizPayload = Dict[str, Any]
# Flask response.
Base = Union[bytes, str]
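VizPayload, like the other names in this module, is a plain type alias: at runtime it is the very object it aliases, and it exists only to make signatures in viz.py readable. Hypothetical usage:

    from typing import Any, Dict

    VizPayload = Dict[str, Any]

    def describe(payload: VizPayload) -> str:
        return f"{len(payload)} keys"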
diff --git a/superset/viz.py b/superset/viz.py
index bf3c110..2e38bf2 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -31,7 +31,7 @@ import uuid
from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta
from itertools import product
-from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
+from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
import dataclasses
import geohash
@@ -55,7 +55,7 @@ from superset.exceptions import (
SpatialException,
)
from superset.models.helpers import QueryResult
-from superset.typing import VizData
+from superset.typing import QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils
from superset.utils.core import (
DTTM_ALIAS,
@@ -101,7 +101,7 @@ class BaseViz:
datasource: "BaseDatasource",
form_data: Dict[str, Any],
force: bool = False,
- ):
+ ) -> None:
if not datasource:
raise Exception(_("Viz is missing a datasource"))
@@ -134,7 +134,7 @@ class BaseViz:
self.process_metrics()
- def process_metrics(self):
+ def process_metrics(self) -> None:
# metrics in TableViz is order sensitive, so metric_dict should be
# OrderedDict
self.metric_dict = OrderedDict()
@@ -153,8 +153,10 @@ class BaseViz:
self.metric_labels = list(self.metric_dict.keys())
@staticmethod
- def handle_js_int_overflow(data):
- for d in data.get("records", dict()):
+ def handle_js_int_overflow(
+ data: Dict[str, List[Dict[str, Any]]]
+ ) -> Dict[str, List[Dict[str, Any]]]:
+ for d in data.get("records", {}):
for k, v in list(d.items()):
if isinstance(v, int):
# if an int is too big for Java Script to handle
@@ -163,7 +165,7 @@ class BaseViz:
d[k] = str(v)
return data
- def run_extra_queries(self):
+ def run_extra_queries(self) -> None:
"""Lifecycle method to use when more than one query is needed
In rare-ish cases, a visualization may need to execute multiple
@@ -186,7 +188,7 @@ class BaseViz:
"""
pass
- def apply_rolling(self, df):
+ def apply_rolling(self, df: pd.DataFrame) -> pd.DataFrame:
fd = self.form_data
rolling_type = fd.get("rolling_type")
rolling_periods = int(fd.get("rolling_periods") or 0)
@@ -206,7 +208,7 @@ class BaseViz:
df = df[min_periods:]
return df
- def get_samples(self):
+ def get_samples(self) -> List[Dict[str, Any]]:
query_obj = self.query_obj()
query_obj.update(
{
@@ -219,7 +221,7 @@ class BaseViz:
df = self.get_df(query_obj)
return df.to_dict(orient="records")
- def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+ def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
query_obj = self.query_obj()
@@ -281,19 +283,19 @@ class BaseViz:
df.replace([np.inf, -np.inf], np.nan, inplace=True)
return df
- def df_metrics_to_num(self, df):
+ def df_metrics_to_num(self, df: pd.DataFrame) -> None:
"""Converting metrics to numeric when pandas.read_sql cannot"""
metrics = self.metric_labels
for col, dtype in df.dtypes.items():
if dtype.type == np.object_ and col in metrics:
df[col] = pd.to_numeric(df[col], errors="coerce")
- def process_query_filters(self):
+ def process_query_filters(self) -> None:
utils.convert_legacy_filters_into_adhoc(self.form_data)
merge_extra_filters(self.form_data)
utils.split_adhoc_filters_into_base_filters(self.form_data)
- def query_obj(self) -> Dict[str, Any]:
+ def query_obj(self) -> QueryObjectDict:
"""Building a query object"""
form_data = self.form_data
self.process_query_filters()
@@ -362,9 +364,9 @@ class BaseViz:
return d
@property
- def cache_timeout(self):
+ def cache_timeout(self) -> int:
if self.form_data.get("cache_timeout") is not None:
- return int(self.form_data.get("cache_timeout"))
+ return int(self.form_data["cache_timeout"])
if self.datasource.cache_timeout is not None:
return self.datasource.cache_timeout
if (
@@ -374,12 +376,12 @@ class BaseViz:
return self.datasource.database.cache_timeout
return config["CACHE_DEFAULT_TIMEOUT"]
- def get_json(self):
+ def get_json(self) -> str:
return json.dumps(
self.get_payload(), default=utils.json_int_dttm_ser, ignore_nan=True
)
- def cache_key(self, query_obj, **extra):
+ def cache_key(self, query_obj: QueryObjectDict, **extra: Any) -> str:
"""
The cache key is made out of the key/values in `query_obj`, plus any
other key/values in `extra`.
@@ -410,7 +412,7 @@ class BaseViz:
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
- def get_payload(self, query_obj=None):
+ def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload:
"""Returns a payload of metadata and data"""
self.run_extra_queries()
payload = self.get_df_payload(query_obj)
@@ -422,7 +424,9 @@ class BaseViz:
del payload["df"]
return payload
- def get_df_payload(self, query_obj=None, **kwargs):
+ def get_df_payload(
+ self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any
+ ) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
if not query_obj:
query_obj = self.query_obj()
@@ -512,21 +516,21 @@ class BaseViz:
"rowcount": len(df.index) if df is not None else 0,
}
- def json_dumps(self, obj, sort_keys=False):
+ def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
return json.dumps(
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
- def payload_json_and_has_error(self, payload):
+ def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
has_error = (
payload.get("status") == utils.QueryStatus.FAILED
or payload.get("error") is not None
- or len(payload.get("errors")) > 0
+ or len(payload.get("errors") or []) > 0
)
return self.json_dumps(payload), has_error
@property
- def data(self):
+ def data(self) -> Dict[str, Any]:
"""This is the data object serialized to the js layer"""
content = {
"form_data": self.form_data,
@@ -536,7 +540,7 @@ class BaseViz:
}
return content
- def get_csv(self):
+ def get_csv(self) -> Optional[str]:
df = self.get_df()
include_index = not isinstance(df.index, pd.RangeIndex)
return df.to_csv(index=include_index, **config["CSV_EXPORT"])
@@ -545,7 +549,7 @@ class BaseViz:
return df.to_dict(orient="records")
@property
- def json_data(self):
+ def json_data(self) -> str:
return json.dumps(self.data)
@@ -559,7 +563,7 @@ class TableViz(BaseViz):
is_timeseries = False
enforce_numerical_metrics = False
- def should_be_timeseries(self):
+ def should_be_timeseries(self) -> bool:
fd = self.form_data
# TODO handle datasource-type-specific code in datasource
conditions_met = (fd.get("granularity") and fd.get("granularity") != "all") or (
@@ -569,9 +573,9 @@ class TableViz(BaseViz):
raise QueryObjectValidationError(
_("Pick a granularity in the Time section or " "uncheck 'Include Time'")
)
- return fd.get("include_time")
+ return bool(fd.get("include_time"))
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
fd = self.form_data
@@ -660,7 +664,7 @@ class TableViz(BaseViz):
return data
- def json_dumps(self, obj, sort_keys=False):
+ def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
return json.dumps(
obj, default=utils.json_iso_dttm_ser, sort_keys=sort_keys, ignore_nan=True
)
@@ -675,14 +679,14 @@ class TimeTableViz(BaseViz):
credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
is_timeseries = True
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
fd = self.form_data
if not fd.get("metrics"):
raise QueryObjectValidationError(_("Pick at least one metric"))
- if fd.get("groupby") and len(fd.get("metrics")) > 1:
+ if fd.get("groupby") and len(fd["metrics"]) > 1:
raise QueryObjectValidationError(
_("When using 'Group By' you are limited to use a single metric")
)
@@ -694,7 +698,7 @@ class TimeTableViz(BaseViz):
fd = self.form_data
columns = None
- values = self.metric_labels
+ values: Union[List[str], str] = self.metric_labels
if fd.get("groupby"):
values = self.metric_labels[0]
columns = fd.get("groupby")
@@ -717,7 +721,7 @@ class PivotTableViz(BaseViz):
credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
groupby = self.form_data.get("groupby")
columns = self.form_data.get("columns")
@@ -798,10 +802,10 @@ class MarkupViz(BaseViz):
verbose_name = _("Markup")
is_timeseries = False
- def query_obj(self):
- return None
+ def query_obj(self) -> QueryObjectDict:
+ return {}
- def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+ def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
return pd.DataFrame()
def get_data(self, df: pd.DataFrame) -> VizData:
@@ -832,7 +836,7 @@ class WordCloudViz(BaseViz):
verbose_name = _("Word Cloud")
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
d["groupby"] = [self.form_data.get("series")]
return d
@@ -847,7 +851,7 @@ class TreemapViz(BaseViz):
credits = '<a href="https://d3js.org">d3.js</a>'
is_timeseries = False
- def _nest(self, metric, df):
+ def _nest(self, metric: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
nlevels = df.index.nlevels
if nlevels == 1:
result = [{"name": n, "value": v} for n, v in zip(df.index, df[metric])]
@@ -927,7 +931,7 @@ class CalHeatmapViz(BaseViz):
"range": range_,
}
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
fd = self.form_data
d["metrics"] = fd.get("metrics")
@@ -953,19 +957,21 @@ class BoxPlotViz(NVD3Viz):
sort_series = False
is_timeseries = True
- def to_series(self, df, classed="", title_suffix=""):
+ def to_series(
+ self, df: pd.DataFrame, classed: str = "", title_suffix: str = ""
+ ) -> List[Dict[str, Any]]:
label_sep = " - "
chart_data = []
for index_value, row in zip(df.index, df.to_dict(orient="records")):
if isinstance(index_value, tuple):
index_value = label_sep.join(index_value)
- boxes = defaultdict(dict)
+ boxes: Dict[str, Dict[str, Any]] = defaultdict(dict)
for (label, key), value in row.items():
if key == "nanmedian":
key = "Q2"
boxes[label][key] = value
for label, box in boxes.items():
- if len(self.form_data.get("metrics")) > 1:
+ if len(self.form_data["metrics"]) > 1:
# need to render data labels with metrics
chart_label = label_sep.join([index_value, label])
else:
@@ -980,46 +986,45 @@ class BoxPlotViz(NVD3Viz):
form_data = self.form_data
# conform to NVD3 names
- def Q1(series): # need to be named functions - can't use lambdas
+ def Q1(series: pd.Series) -> float:
+ # need to be named functions - can't use lambdas
return np.nanpercentile(series, 25)
- def Q3(series):
+ def Q3(series: pd.Series) -> float:
return np.nanpercentile(series, 75)
whisker_type = form_data.get("whisker_options")
if whisker_type == "Tukey":
- def whisker_high(series):
+ def whisker_high(series: pd.Series) -> float:
upper_outer_lim = Q3(series) + 1.5 * (Q3(series) - Q1(series))
return series[series <= upper_outer_lim].max()
- def whisker_low(series):
+ def whisker_low(series: pd.Series) -> float:
lower_outer_lim = Q1(series) - 1.5 * (Q3(series) - Q1(series))
return series[series >= lower_outer_lim].min()
elif whisker_type == "Min/max (no outliers)":
- def whisker_high(series):
+ def whisker_high(series: pd.Series) -> float:
return series.max()
- def whisker_low(series):
+ def whisker_low(series: pd.Series) -> float:
return series.min()
elif " percentiles" in whisker_type: # type: ignore
- low, high = whisker_type.replace(" percentiles", "").split( # type: ignore
- "/"
- )
+ low, high = cast(str, whisker_type).replace(" percentiles", "").split("/")
- def whisker_high(series):
+ def whisker_high(series: pd.Series) -> float:
return np.nanpercentile(series, int(high))
- def whisker_low(series):
+ def whisker_low(series: pd.Series) -> float:
return np.nanpercentile(series, int(low))
else:
raise ValueError("Unknown whisker type: {}".format(whisker_type))
- def outliers(series):
+ def outliers(series: pd.Series) -> Set[float]:
above = series[series > whisker_high(series)]
below = series[series < whisker_low(series)]
# pandas sometimes doesn't like getting lists back here
@@ -1039,7 +1044,7 @@ class BubbleViz(NVD3Viz):
verbose_name = _("Bubble Chart")
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
form_data = self.form_data
d = super().query_obj()
d["groupby"] = [form_data.get("entity")]
@@ -1090,7 +1095,7 @@ class BulletViz(NVD3Viz):
verbose_name = _("Bullet Chart")
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
form_data = self.form_data
d = super().query_obj()
self.metric = form_data["metric"]
@@ -1117,7 +1122,7 @@ class BigNumberViz(BaseViz):
credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
is_timeseries = True
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
metric = self.form_data.get("metric")
if not metric:
@@ -1151,7 +1156,7 @@ class BigNumberTotalViz(BaseViz):
credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
metric = self.form_data.get("metric")
if not metric:
@@ -1174,7 +1179,9 @@ class NVD3TimeSeriesViz(NVD3Viz):
is_timeseries = True
pivot_fill_value: Optional[int] = None
- def to_series(self, df, classed="", title_suffix=""):
+ def to_series(
+ self, df: pd.DataFrame, classed: str = "", title_suffix: str = ""
+ ) -> List[Dict[str, Any]]:
cols = []
for col in df.columns:
if col == "":
@@ -1191,6 +1198,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
ys = series[name]
if df[name].dtype.kind not in "biufc":
continue
+ series_title: Union[List[str], str, Tuple[str, ...]]
if isinstance(name, list):
series_title = [str(title) for title in name]
elif isinstance(name, tuple):
@@ -1207,7 +1215,9 @@ class NVD3TimeSeriesViz(NVD3Viz):
if title_suffix:
if isinstance(series_title, str):
series_title = (series_title, title_suffix)
- elif isinstance(series_title, (list, tuple)):
+ elif isinstance(series_title, list):
+ series_title = series_title + [title_suffix]
+ elif isinstance(series_title, tuple):
series_title = series_title + (title_suffix,)
values = []
@@ -1274,7 +1284,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
return df
- def run_extra_queries(self):
+ def run_extra_queries(self) -> None:
fd = self.form_data
time_compare = fd.get("time_compare") or []
@@ -1364,8 +1374,8 @@ class MultiLineViz(NVD3Viz):
is_timeseries = True
- def query_obj(self):
- return None
+ def query_obj(self) -> QueryObjectDict:
+ return {}
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
@@ -1394,7 +1404,7 @@ class NVD3DualLineViz(NVD3Viz):
sort_series = False
is_timeseries = True
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
m1 = self.form_data.get("metric")
m2 = self.form_data.get("metric_2")
@@ -1409,7 +1419,7 @@ class NVD3DualLineViz(NVD3Viz):
)
return d
- def to_series(self, df, classed=""):
+ def to_series(self, df: pd.DataFrame, classed: str = "") -> List[Dict[str, Any]]:
cols = []
for col in df.columns:
if col == "":
@@ -1421,7 +1431,7 @@ class NVD3DualLineViz(NVD3Viz):
df.columns = cols
series = df.to_dict("series")
chart_data = []
- metrics = [self.form_data.get("metric"), self.form_data.get("metric_2")]
+ metrics = [self.form_data["metric"], self.form_data["metric_2"]]
for i, m in enumerate(metrics):
m = utils.get_metric_name(m)
ys = series[m]
@@ -1476,7 +1486,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
sort_series = True
verbose_name = _("Time Series - Period Pivot")
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
d["metrics"] = [self.form_data.get("metric")]
return d
@@ -1561,7 +1571,7 @@ class HistogramViz(BaseViz):
verbose_name = _("Histogram")
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
"""Returns the query object for this visualization"""
d = super().query_obj()
d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"]))
@@ -1576,9 +1586,9 @@ class HistogramViz(BaseViz):
d["groupby"] = []
return d
- def labelify(self, keys, column):
+ def labelify(self, keys: Union[List[str], str], column: str) -> str:
if isinstance(keys, str):
- keys = (keys,)
+ keys = [keys]
# removing undesirable characters
labels = [re.sub(r"\W+", r"_", k) for k in keys]
if len(self.columns) > 1 or not self.groupby:
@@ -1617,7 +1627,7 @@ class DistributionBarViz(DistributionPieViz):
verbose_name = _("Distribution - Bar Chart")
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
fd = self.form_data
if len(d["groupby"]) < len(fd.get("groupby") or []) + len(
@@ -1708,7 +1718,7 @@ class SunburstViz(BaseViz):
df = df[cols]
return df.to_numpy().tolist()
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
qry = super().query_obj()
fd = self.form_data
qry["metrics"] = [fd["metric"]]
@@ -1727,7 +1737,7 @@ class SankeyViz(BaseViz):
is_timeseries = False
credits = '<a href="https://www.npmjs.com/package/d3-sankey">d3-sankey on npm</a>'
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
qry = super().query_obj()
if len(qry["groupby"]) != 2:
raise QueryObjectValidationError(
@@ -1746,21 +1756,23 @@ class SankeyViz(BaseViz):
for row in recs:
hierarchy[row["source"]].add(row["target"])
- def find_cycle(g):
+ def find_cycle(g: Dict[str, Set[str]]) -> Optional[Tuple[str, str]]:
"""Whether there's a cycle in a directed graph"""
path = set()
- def visit(vertex):
+ def visit(vertex: str) -> Optional[Tuple[str, str]]:
path.add(vertex)
for neighbour in g.get(vertex, ()):
if neighbour in path or visit(neighbour):
return (vertex, neighbour)
path.remove(vertex)
+ return None
for v in g:
cycle = visit(v)
if cycle:
return cycle
+ return None
cycle = find_cycle(hierarchy)
if cycle:
@@ -1782,7 +1794,7 @@ class DirectedForceViz(BaseViz):
credits = 'd3noob @<a href="http://bl.ocks.org/d3noob/5141278">bl.ocks.org</a>'
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
qry = super().query_obj()
if len(self.form_data["groupby"]) != 2:
raise QueryObjectValidationError(_("Pick exactly 2 columns to 'Group By'"))
@@ -1803,7 +1815,7 @@ class ChordViz(BaseViz):
credits = '<a href="https://github.com/d3/d3-chord">Bostock</a>'
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
qry = super().query_obj()
fd = self.form_data
qry["groupby"] = [fd.get("groupby"), fd.get("columns")]
@@ -1836,7 +1848,7 @@ class CountryMapViz(BaseViz):
is_timeseries = False
credits = "From bl.ocks.org By john-guerra"
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
qry = super().query_obj()
qry["metrics"] = [self.form_data["metric"]]
qry["groupby"] = [self.form_data["entity"]]
@@ -1863,7 +1875,7 @@ class WorldMapViz(BaseViz):
is_timeseries = False
credits = 'datamaps on <a href="https://www.npmjs.com/package/datamaps">npm</a>'
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
qry = super().query_obj()
qry["groupby"] = [self.form_data["entity"]]
return qry
@@ -1923,10 +1935,10 @@ class FilterBoxViz(BaseViz):
cache_type = "get_data"
filter_row_limit = 1000
- def query_obj(self):
- return None
+ def query_obj(self) -> QueryObjectDict:
+ return {}
- def run_extra_queries(self):
+ def run_extra_queries(self) -> None:
qry = super().query_obj()
filters = self.form_data.get("filter_configs") or []
qry["row_limit"] = self.filter_row_limit
@@ -1979,10 +1991,10 @@ class IFrameViz(BaseViz):
credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
is_timeseries = False
- def query_obj(self):
- return None
+ def query_obj(self) -> QueryObjectDict:
+ return {}
- def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+ def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
return pd.DataFrame()
def get_data(self, df: pd.DataFrame) -> VizData:
@@ -2005,7 +2017,7 @@ class ParallelCoordinatesViz(BaseViz):
)
is_timeseries = False
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
fd = self.form_data
d["groupby"] = [fd.get("series")]
@@ -2027,7 +2039,7 @@ class HeatmapViz(BaseViz):
"bl.ocks.org</a>"
)
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
fd = self.form_data
d["metrics"] = [fd.get("metric")]
@@ -2092,7 +2104,7 @@ class MapboxViz(BaseViz):
is_timeseries = False
credits = "<a href=https://www.mapbox.com/mapbox-gl-js/api/>Mapbox GL JS</a>"
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
fd = self.form_data
label_col = fd.get("mapbox_label")
@@ -2124,22 +2136,24 @@ class MapboxViz(BaseViz):
label_col
and len(label_col) >= 1
and label_col[0] != "count"
- and label_col[0] not in fd.get("groupby")
+ and label_col[0] not in fd["groupby"]
):
raise QueryObjectValidationError(
_("Choice of [Label] must be present in [Group By]")
)
- if fd.get("point_radius") != "Auto" and fd.get(
- "point_radius"
- ) not in fd.get("groupby"):
+ if (
+ fd.get("point_radius") != "Auto"
+ and fd.get("point_radius") not in fd["groupby"]
+ ):
raise QueryObjectValidationError(
_("Choice of [Point Radius] must be present in [Group By]")
)
- if fd.get("all_columns_x") not in fd.get("groupby") or fd.get(
- "all_columns_y"
- ) not in fd.get("groupby"):
+ if (
+ fd.get("all_columns_x") not in fd["groupby"]
+ or fd.get("all_columns_y") not in fd["groupby"]
+ ):
raise QueryObjectValidationError(
_(
"[Longitude] and [Latitude] columns must be present in "
@@ -2226,8 +2240,8 @@ class DeckGLMultiLayer(BaseViz):
is_timeseries = False
credits = '<a href="https://uber.github.io/deck.gl/">deck.gl</a>'
- def query_obj(self):
- return None
+ def query_obj(self) -> QueryObjectDict:
+ return {}
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
@@ -2251,14 +2265,14 @@ class BaseDeckGLViz(BaseViz):
credits = '<a href="https://uber.github.io/deck.gl/">deck.gl</a>'
spatial_control_keys: List[str] = []
- def get_metrics(self):
+ def get_metrics(self) -> List[str]:
self.metric = self.form_data.get("size")
return [self.metric] if self.metric else []
- def process_spatial_query_obj(self, key, group_by):
+ def process_spatial_query_obj(self, key: str, group_by: List[str]) -> None:
group_by.extend(self.get_spatial_columns(key))
- def get_spatial_columns(self, key):
+ def get_spatial_columns(self, key: str) -> List[str]:
spatial = self.form_data.get(key)
if spatial is None:
raise ValueError(_("Bad spatial key"))
@@ -2269,9 +2283,10 @@ class BaseDeckGLViz(BaseViz):
return [spatial.get("lonlatCol")]
elif spatial.get("type") == "geohash":
return [spatial.get("geohashCol")]
+ return []
@staticmethod
- def parse_coordinates(s):
+ def parse_coordinates(s: Any) -> Optional[Tuple[float, float]]:
if not s:
return None
try:
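
`parse_coordinates` is typed `Any -> Optional[Tuple[float, float]]`: form data arrives untyped at this boundary, while the return annotation pins down the contract the deck.gl code relies on. A stand-in with a hypothetical comma-separated parser (the real method delegates to a point-parsing library and raises SpatialException):

    from typing import Any, Optional, Tuple

    def parse_point(s: Any) -> Optional[Tuple[float, float]]:
        # Untrusted input is Any; the precise Optional tuple is the contract.
        if not s:
            return None
        try:
            lat, lng = (float(part) for part in str(s).split(","))
            return (lng, lat)
        except ValueError:
            raise ValueError("Invalid spatial point encountered: %s" % s)
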
@@ -2281,15 +2296,15 @@ class BaseDeckGLViz(BaseViz):
raise SpatialException(_("Invalid spatial point encountered: %s" % s))
@staticmethod
- def reverse_geohash_decode(geohash_code):
+ def reverse_geohash_decode(geohash_code: str) -> Tuple[str, str]:
lat, lng = geohash.decode(geohash_code)
return (lng, lat)
@staticmethod
- def reverse_latlong(df, key):
+ def reverse_latlong(df: pd.DataFrame, key: str) -> None:
df[key] = [tuple(reversed(o)) for o in df[key] if isinstance(o, (list, tuple))]
- def process_spatial_data_obj(self, key, df):
+ def process_spatial_data_obj(self, key: str, df: pd.DataFrame) -> pd.DataFrame:
spatial = self.form_data.get(key)
if spatial is None:
raise ValueError(_("Bad spatial key"))
@@ -2321,7 +2336,7 @@ class BaseDeckGLViz(BaseViz):
)
return df
- def add_null_filters(self):
+ def add_null_filters(self) -> None:
fd = self.form_data
spatial_columns = set()
for key in self.spatial_control_keys:
@@ -2339,7 +2354,7 @@ class BaseDeckGLViz(BaseViz):
filter_ = to_adhoc({"col": column, "op": "IS NOT NULL", "val": ""})
fd["adhoc_filters"].append(filter_)
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
fd = self.form_data
# add NULL filters
@@ -2347,16 +2362,16 @@ class BaseDeckGLViz(BaseViz):
self.add_null_filters()
d = super().query_obj()
- gb = []
+ gb: List[str] = []
for key in self.spatial_control_keys:
self.process_spatial_query_obj(key, gb)
if fd.get("dimension"):
- gb += [fd.get("dimension")]
+ gb += [fd["dimension"]]
if fd.get("js_columns"):
- gb += fd.get("js_columns")
+ gb += fd.get("js_columns") or []
metrics = self.get_metrics()
gb = list(set(gb))
if metrics:
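
Two small idioms in the `query_obj` body above carry most of the typing weight: an empty list needs an explicit `List[str]` annotation before mypy can check what gets added to it, and `fd.get("js_columns") or []` collapses the None case even though the surrounding `if` already guards it, because mypy does not carry narrowing across two separate `.get()` calls. Sketch, with `grouping_columns` as an illustrative name:

    from typing import Any, Dict, List

    def grouping_columns(fd: Dict[str, Any]) -> List[str]:
        gb: List[str] = []  # bare [] would leave mypy without an element type
        if fd.get("dimension"):
            gb.append(fd["dimension"])        # index after the truthiness guard
        if fd.get("js_columns"):
            gb += fd.get("js_columns") or []  # `or []` strips the None case
        return list(set(gb))
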
@@ -2367,7 +2382,7 @@ class BaseDeckGLViz(BaseViz):
d["columns"] = gb
return d
- def get_js_columns(self, d):
+ def get_js_columns(self, d: Dict[str, Any]) -> Dict[str, Any]:
cols = self.form_data.get("js_columns") or []
return {col: d.get(col) for col in cols}
@@ -2393,7 +2408,7 @@ class BaseDeckGLViz(BaseViz):
"metricLabels": self.metric_labels,
}
- def get_properties(self, d):
+ def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
raise NotImplementedError()
@@ -2406,7 +2421,7 @@ class DeckScatterViz(BaseDeckGLViz):
spatial_control_keys = ["spatial"]
is_timeseries = True
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
fd = self.form_data
self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
self.point_radius_fixed = fd.get("point_radius_fixed") or {
@@ -2415,19 +2430,21 @@ class DeckScatterViz(BaseDeckGLViz):
}
return super().query_obj()
- def get_metrics(self):
+ def get_metrics(self) -> List[str]:
self.metric = None
if self.point_radius_fixed.get("type") == "metric":
- self.metric = self.point_radius_fixed.get("value")
+ self.metric = self.point_radius_fixed["value"]
return [self.metric]
- return None
+ return []
- def get_properties(self, d):
+ def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
return {
- "metric": d.get(self.metric_label),
+ "metric": d.get(self.metric_label) if self.metric_label else None,
"radius": self.fixed_value
if self.fixed_value
- else d.get(self.metric_label),
+ else d.get(self.metric_label)
+ if self.metric_label
+ else None,
"cat_color": d.get(self.dim) if self.dim else None,
"position": d.get("spatial"),
DTTM_ALIAS: d.get(DTTM_ALIAS),
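
`self.metric_label` stays `Optional[str]` on DeckScatterViz because it is only populated when a size metric exists, so every dictionary lookup through it gains a guard; the stacked conditional expression for "radius" is the typed spelling of "fixed value, else metric value, else None". A compact equivalent (class and method names are illustrative):

    from typing import Any, Dict, Optional

    class ScatterProps:
        metric_label: Optional[str] = None  # only set when a size metric exists
        fixed_value: Optional[float] = None

        def radius(self, d: Dict[str, Any]) -> Optional[float]:
            if self.fixed_value:
                return self.fixed_value
            # Guard before using the label as a key: d.get(None) is legal at
            # runtime but would paper over the missing metric.
            return d.get(self.metric_label) if self.metric_label else None
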
@@ -2453,20 +2470,20 @@ class DeckScreengrid(BaseDeckGLViz):
spatial_control_keys = ["spatial"]
is_timeseries = True
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
fd = self.form_data
- self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity")
+ self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
return super().query_obj()
- def get_properties(self, d):
+ def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
return {
"position": d.get("spatial"),
- "weight": d.get(self.metric_label) or 1,
+ "weight": (d.get(self.metric_label) if self.metric_label else None) or 1,
"__timestamp": d.get(DTTM_ALIAS) or d.get("__time"),
}
def get_data(self, df: pd.DataFrame) -> VizData:
- self.metric_label = utils.get_metric_name(self.metric)
+ self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
return super().get_data(df)
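
The `bool(...)` wrapper recurs here in DeckScreengrid and again in DeckPathViz and DeckArc below: `fd.get(a) or fd.get(b)` evaluates to whichever operand is truthy (a string, or None), so assigning it raw would silently retype the `is_timeseries` attribute. Sketch, with a hypothetical class name:

    from typing import Any, Dict

    class DeckLayer:
        is_timeseries: bool = True

        def configure(self, fd: Dict[str, Any]) -> None:
            # Without bool(), the attribute would take whatever value the
            # `or` chain produced, e.g. "P1D" or None, not True/False.
            self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
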
@@ -2478,15 +2495,18 @@ class DeckGrid(BaseDeckGLViz):
verbose_name = _("Deck.gl - 3D Grid")
spatial_control_keys = ["spatial"]
- def get_properties(self, d):
- return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
+ def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
+ return {
+ "position": d.get("spatial"),
+ "weight": (d.get(self.metric_label) if self.metric_label else None) or 1,
+ }
def get_data(self, df: pd.DataFrame) -> VizData:
- self.metric_label = utils.get_metric_name(self.metric)
+ self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
return super().get_data(df)
-def geohash_to_json(geohash_code):
+def geohash_to_json(geohash_code: str) -> List[List[float]]:
p = geohash.bbox(geohash_code)
return [
[p.get("w"), p.get("n")],
@@ -2511,9 +2531,9 @@ class DeckPathViz(BaseDeckGLViz):
"geohash": geohash_to_json,
}
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
fd = self.form_data
- self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity")
+ self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
d = super().query_obj()
self.metric = fd.get("metric")
line_col = fd.get("line_column")
@@ -2525,11 +2545,11 @@ class DeckPathViz(BaseDeckGLViz):
d["columns"].append(line_col)
return d
- def get_properties(self, d):
+ def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
fd = self.form_data
- line_type = fd.get("line_type")
+ line_type = fd["line_type"]
deser = self.deser_map[line_type]
- line_column = fd.get("line_column")
+ line_column = fd["line_column"]
path = deser(d[line_column])
if fd.get("reverse_long_lat"):
path = [(o[1], o[0]) for o in path]
@@ -2540,7 +2560,7 @@ class DeckPathViz(BaseDeckGLViz):
return d
def get_data(self, df: pd.DataFrame) -> VizData:
- self.metric_label = utils.get_metric_name(self.metric)
+ self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
return super().get_data(df)
@@ -2552,18 +2572,18 @@ class DeckPolygon(DeckPathViz):
deck_viz_key = "polygon"
verbose_name = _("Deck.gl - Polygon")
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
fd = self.form_data
self.elevation = fd.get("point_radius_fixed") or {"type": "fix", "value": 500}
return super().query_obj()
- def get_metrics(self):
+ def get_metrics(self) -> List[str]:
metrics = [self.form_data.get("metric")]
if self.elevation.get("type") == "metric":
metrics.append(self.elevation.get("value"))
return [metric for metric in metrics if metric]
- def get_properties(self, d):
+ def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
super().get_properties(d)
fd = self.form_data
elevation = fd["point_radius_fixed"]["value"]
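
DeckPolygon.get_metrics assembles its list from two `.get()` lookups, so the entries start out Optional; the final comprehension's truthiness filter is what narrows them and lets the declared `List[str]` hold without casts. A reduced sketch (dict parameter types tightened to `str` purely to make the narrowing visible):

    from typing import Dict, List, Optional

    def collect_metrics(fd: Dict[str, str], elevation: Dict[str, str]) -> List[str]:
        metrics: List[Optional[str]] = [fd.get("metric")]
        if elevation.get("type") == "metric":
            metrics.append(elevation.get("value"))
        # `if metric` narrows Optional[str] to str inside the comprehension,
        # so the result satisfies List[str].
        return [metric for metric in metrics if metric]
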
@@ -2582,11 +2602,14 @@ class DeckHex(BaseDeckGLViz):
verbose_name = _("Deck.gl - 3D HEX")
spatial_control_keys = ["spatial"]
- def get_properties(self, d):
- return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
+ def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
+ return {
+ "position": d.get("spatial"),
+ "weight": (d.get(self.metric_label) if self.metric_label else None) or 1,
+ }
def get_data(self, df: pd.DataFrame) -> VizData:
- self.metric_label = utils.get_metric_name(self.metric)
+ self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
return super(DeckHex, self).get_data(df)
@@ -2597,15 +2620,15 @@ class DeckGeoJson(BaseDeckGLViz):
viz_type = "deck_geojson"
verbose_name = _("Deck.gl - GeoJSON")
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
d["columns"] += [self.form_data.get("geojson")]
d["metrics"] = []
d["groupby"] = []
return d
- def get_properties(self, d):
- geojson = d.get(self.form_data.get("geojson"))
+ def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
+ geojson = d[self.form_data["geojson"]]
return json.loads(geojson)
@@ -2618,12 +2641,12 @@ class DeckArc(BaseDeckGLViz):
spatial_control_keys = ["start_spatial", "end_spatial"]
is_timeseries = True
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
fd = self.form_data
self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
return super().query_obj()
- def get_properties(self, d):
+ def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
dim = self.form_data.get("dimension")
return {
"sourcePosition": d.get("start_spatial"),
@@ -2653,15 +2676,15 @@ class EventFlowViz(BaseViz):
credits = 'from <a href="https://github.com/williaster/data-ui">@data-ui</a>'
is_timeseries = True
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
query = super().query_obj()
form_data = self.form_data
- event_key = form_data.get("all_columns_x")
- entity_key = form_data.get("entity")
+ event_key = form_data["all_columns_x"]
+ entity_key = form_data["entity"]
meta_keys = [
col
- for col in form_data.get("all_columns")
+ for col in form_data["all_columns"]
if col != event_key and col != entity_key
]
@@ -2773,14 +2796,16 @@ class PartitionViz(NVD3TimeSeriesViz):
viz_type = "partition"
verbose_name = _("Partition Diagram")
- def query_obj(self):
+ def query_obj(self) -> QueryObjectDict:
query_obj = super().query_obj()
time_op = self.form_data.get("time_series_option", "not_time")
# Return time series data if the user specifies so
query_obj["is_timeseries"] = time_op != "not_time"
return query_obj
- def levels_for(self, time_op, groups, df):
+ def levels_for(
+ self, time_op: str, groups: List[str], df: pd.DataFrame
+ ) -> Dict[int, pd.Series]:
"""
Compute the partition at each `level` from the dataframe.
"""
@@ -2794,7 +2819,9 @@ class PartitionViz(NVD3TimeSeriesViz):
)
return levels
- def levels_for_diff(self, time_op, groups, df):
+ def levels_for_diff(
+ self, time_op: str, groups: List[str], df: pd.DataFrame
+ ) -> Dict[int, pd.DataFrame]:
# Obtain a unique list of the time grains
times = list(set(df[DTTM_ALIAS]))
times.sort()
@@ -2828,7 +2855,9 @@ class PartitionViz(NVD3TimeSeriesViz):
)
return levels
- def levels_for_time(self, groups, df):
+ def levels_for_time(
+ self, groups: List[str], df: pd.DataFrame
+ ) -> Dict[int, VizData]:
procs = {}
for i in range(0, len(groups) + 1):
self.form_data["groupby"] = groups[:i]
@@ -2837,11 +2866,19 @@ class PartitionViz(NVD3TimeSeriesViz):
self.form_data["groupby"] = groups
return procs
- def nest_values(self, levels, level=0, metric=None, dims=()):
+ def nest_values(
+ self,
+ levels: Dict[int, pd.DataFrame],
+ level: int = 0,
+ metric: Optional[str] = None,
+ dims: Optional[List[str]] = None,
+ ) -> List[Dict[str, Any]]:
"""
Nest values at each level on the back-end with
access and setting, instead of summing from the bottom.
"""
+ if dims is None:
+ dims = []
if not level:
return [
{
@@ -2856,7 +2893,7 @@ class PartitionViz(NVD3TimeSeriesViz):
{
"name": i,
"val": levels[1][metric][i],
- "children": self.nest_values(levels, 2, metric, (i,)),
+ "children": self.nest_values(levels, 2, metric, [i]),
}
for i in levels[1][metric].index
]
@@ -2866,12 +2903,20 @@ class PartitionViz(NVD3TimeSeriesViz):
{
"name": i,
"val": levels[level][metric][dims][i],
- "children": self.nest_values(levels, level + 1, metric, dims + (i,)),
+ "children": self.nest_values(levels, level + 1, metric, dims + [i]),
}
for i in levels[level][metric][dims].index
]
- def nest_procs(self, procs, level=-1, dims=(), time=None):
+ def nest_procs(
+ self,
+ procs: Dict[int, pd.DataFrame],
+ level: int = -1,
+ dims: Optional[Tuple[str, ...]] = None,
+ time: Any = None,
+ ) -> List[Dict[str, Any]]:
+ if dims is None:
+ dims = ()
if level == -1:
return [
{"name": m, "children": self.nest_procs(procs, 0, (m,))}
diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py
index f51580d..1df2528 100644
--- a/superset/viz_sip38.py
+++ b/superset/viz_sip38.py
@@ -20,6 +20,7 @@
These objects represent the backend of all the visualizations that
Superset can render.
"""
+# mypy: ignore-errors
import copy
import hashlib
import inspect
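
Rather than annotating the SIP-38 work-in-progress copy of viz.py twice, the commit opts the whole module out with mypy's inline configuration comment; the file is still parsed and its imports followed, but errors reported inside it are discarded. Minimal illustration:

    # mypy: ignore-errors
    """A module excluded from checking while it is being rewritten."""

    def untyped(x):  # would trip disallow_untyped_defs anywhere else in the package
        return x
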
@@ -610,7 +611,7 @@ class TableViz(BaseViz):
raise QueryObjectValidationError(
_("Pick a granularity in the Time section or " "uncheck 'Include Time'")
)
- return fd.get("include_time")
+ return bool(fd.get("include_time"))
def query_obj(self):
d = super().query_obj()
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index 748d50b..f8eb8ce 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -974,7 +974,7 @@ class BaseDeckGLVizTestCase(SupersetTestCase):
test_viz_deckgl = viz.DeckScatterViz(datasource, form_data)
test_viz_deckgl.point_radius_fixed = {}
result = test_viz_deckgl.get_metrics()
- assert result is None
+ assert result == []
def test_get_js_columns(self):
form_data = load_fixture("deck_path_form_data.json")