You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this plain-text copy.
Posted to commits@superset.apache.org by jo...@apache.org on 2020/06/03 22:26:32 UTC

[incubator-superset] branch master updated: style(mypy): Enforcing typing for superset (#9943)

This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 244677c  style(mypy): Enforcing typing for superset (#9943)
244677c is described below

commit 244677cf5e0ecb7c767455e96655af6c18cc58bc
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Wed Jun 3 15:26:12 2020 -0700

    style(mypy): Enforcing typing for superset (#9943)
    
    Co-authored-by: John Bodley <jo...@airbnb.com>
---
 setup.cfg                 |   2 +-
 superset/app.py           |  66 +++++----
 superset/cli.py           |  56 ++++----
 superset/config.py        |  21 +--
 superset/exceptions.py    |   2 +-
 superset/extensions.py    |  45 +++---
 superset/forms.py         |  16 +--
 superset/jinja_context.py |  26 ++--
 superset/sql_lab.py       |  76 +++++-----
 superset/sql_parse.py     |   6 +-
 superset/stats_logger.py  |  27 ++--
 superset/typing.py        |   1 +
 superset/viz.py           | 357 ++++++++++++++++++++++++++--------------------
 superset/viz_sip38.py     |   3 +-
 tests/viz_tests.py        |   2 +-
 15 files changed, 393 insertions(+), 313 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 1115de9..fc94a24 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -53,7 +53,7 @@ order_by_type = false
 ignore_missing_imports = true
 no_implicit_optional = true
 
-[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
+[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset [...]
 check_untyped_defs = true
 disallow_untyped_calls = true
 disallow_untyped_defs = true
diff --git a/superset/app.py b/superset/app.py
index 98f1459..18165ed 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -17,6 +17,7 @@
 
 import logging
 import os
+from typing import Any, Callable, Dict
 
 import wtforms_json
 from flask import Flask, redirect
@@ -41,13 +42,14 @@ from superset.extensions import (
     talisman,
 )
 from superset.security import SupersetSecurityManager
+from superset.typing import FlaskResponse
 from superset.utils.core import pessimistic_connection_handling
 from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
 
 logger = logging.getLogger(__name__)
 
 
-def create_app():
+def create_app() -> Flask:
     app = Flask(__name__)
 
     try:
@@ -68,7 +70,7 @@ def create_app():
 
 class SupersetIndexView(IndexView):
     @expose("/")
-    def index(self):
+    def index(self) -> FlaskResponse:
         return redirect("/superset/welcome")
 
 
@@ -109,8 +111,8 @@ class SupersetAppInitializer:
             abstract = True
 
             # Grab each call into the task and set up an app context
-            def __call__(self, *args, **kwargs):
-                with flask_app.app_context():
+            def __call__(self, *args: Any, **kwargs: Any) -> Any:
+                with flask_app.app_context():  # type: ignore
                     return task_base.__call__(self, *args, **kwargs)
 
         celery_app.Task = AppContextTask
@@ -454,51 +456,41 @@ class SupersetAppInitializer:
         order to fully init the app
         """
         self.pre_init()
-
         self.setup_db()
-
         self.configure_celery()
-
         self.setup_event_logger()
-
         self.setup_bundle_manifest()
-
         self.register_blueprints()
-
         self.configure_wtf()
-
         self.configure_logging()
-
         self.configure_middlewares()
-
         self.configure_cache()
-
         self.configure_jinja_context()
 
-        with self.flask_app.app_context():
+        with self.flask_app.app_context():  # type: ignore
             self.init_app_in_ctx()
 
         self.post_init()
 
-    def setup_event_logger(self):
+    def setup_event_logger(self) -> None:
         _event_logger["event_logger"] = get_event_logger_from_cfg_value(
             self.flask_app.config.get("EVENT_LOGGER", DBEventLogger())
         )
 
-    def configure_data_sources(self):
+    def configure_data_sources(self) -> None:
         # Registering sources
         module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
         module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
         ConnectorRegistry.register_sources(module_datasource_map)
 
-    def configure_cache(self):
+    def configure_cache(self) -> None:
         cache_manager.init_app(self.flask_app)
         results_backend_manager.init_app(self.flask_app)
 
-    def configure_feature_flags(self):
+    def configure_feature_flags(self) -> None:
         feature_flag_manager.init_app(self.flask_app)
 
-    def configure_fab(self):
+    def configure_fab(self) -> None:
         if self.config["SILENCE_FAB"]:
             logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)
 
@@ -516,7 +508,7 @@ class SupersetAppInitializer:
         appbuilder.update_perms = False
         appbuilder.init_app(self.flask_app, db.session)
 
-    def configure_url_map_converters(self):
+    def configure_url_map_converters(self) -> None:
         #
         # Doing local imports here as model importing causes a reference to
         # app.config to be invoked and we need the current_app to have been setup
@@ -527,10 +519,10 @@ class SupersetAppInitializer:
         self.flask_app.url_map.converters["regex"] = RegexConverter
         self.flask_app.url_map.converters["object_type"] = ObjectTypeConverter
 
-    def configure_jinja_context(self):
+    def configure_jinja_context(self) -> None:
         jinja_context_manager.init_app(self.flask_app)
 
-    def configure_middlewares(self):
+    def configure_middlewares(self) -> None:
         if self.config["ENABLE_CORS"]:
             from flask_cors import CORS
 
@@ -539,24 +531,28 @@ class SupersetAppInitializer:
         if self.config["ENABLE_PROXY_FIX"]:
             from werkzeug.middleware.proxy_fix import ProxyFix
 
-            self.flask_app.wsgi_app = ProxyFix(
+            self.flask_app.wsgi_app = ProxyFix(  # type: ignore
                 self.flask_app.wsgi_app, **self.config["PROXY_FIX_CONFIG"]
             )
 
         if self.config["ENABLE_CHUNK_ENCODING"]:
 
             class ChunkedEncodingFix:  # pylint: disable=too-few-public-methods
-                def __init__(self, app):
+                def __init__(self, app: Flask) -> None:
                     self.app = app
 
-                def __call__(self, environ, start_response):
+                def __call__(
+                    self, environ: Dict[str, Any], start_response: Callable
+                ) -> Any:
                     # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
                     # content-length and read the stream till the end.
                     if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == "chunked":
                         environ["wsgi.input_terminated"] = True
                     return self.app(environ, start_response)
 
-            self.flask_app.wsgi_app = ChunkedEncodingFix(self.flask_app.wsgi_app)
+            self.flask_app.wsgi_app = ChunkedEncodingFix(  # type: ignore
+                self.flask_app.wsgi_app  # type: ignore
+            )
 
         if self.config["UPLOAD_FOLDER"]:
             try:
@@ -565,7 +561,9 @@ class SupersetAppInitializer:
                 pass
 
         for middleware in self.config["ADDITIONAL_MIDDLEWARE"]:
-            self.flask_app.wsgi_app = middleware(self.flask_app.wsgi_app)
+            self.flask_app.wsgi_app = middleware(  # type: ignore
+                self.flask_app.wsgi_app
+            )
 
         # Flask-Compress
         if self.config["ENABLE_FLASK_COMPRESS"]:
@@ -574,27 +572,27 @@ class SupersetAppInitializer:
         if self.config["TALISMAN_ENABLED"]:
             talisman.init_app(self.flask_app, **self.config["TALISMAN_CONFIG"])
 
-    def configure_logging(self):
+    def configure_logging(self) -> None:
         self.config["LOGGING_CONFIGURATOR"].configure_logging(
             self.config, self.flask_app.debug
         )
 
-    def setup_db(self):
+    def setup_db(self) -> None:
         db.init_app(self.flask_app)
 
-        with self.flask_app.app_context():
+        with self.flask_app.app_context():  # type: ignore
             pessimistic_connection_handling(db.engine)
 
         migrate.init_app(self.flask_app, db=db, directory=APP_DIR + "/migrations")
 
-    def configure_wtf(self):
+    def configure_wtf(self) -> None:
         if self.config["WTF_CSRF_ENABLED"]:
             csrf = CSRFProtect(self.flask_app)
             csrf_exempt_list = self.config["WTF_CSRF_EXEMPT_LIST"]
             for ex in csrf_exempt_list:
                 csrf.exempt(ex)
 
-    def register_blueprints(self):
+    def register_blueprints(self) -> None:
         for bp in self.config["BLUEPRINTS"]:
             try:
                 logger.info(f"Registering blueprint: '{bp.name}'")
@@ -602,5 +600,5 @@ class SupersetAppInitializer:
             except Exception:  # pylint: disable=broad-except
                 logger.exception("blueprint registration failed")
 
-    def setup_bundle_manifest(self):
+    def setup_bundle_manifest(self) -> None:
         manifest_processor.init_app(self.flask_app)
diff --git a/superset/cli.py b/superset/cli.py
index 090e1fb..a136010 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -19,10 +19,11 @@ import logging
 from datetime import datetime
 from subprocess import Popen
 from sys import stdout
-from typing import Type, Union
+from typing import Any, Dict, Type, Union
 
 import click
 import yaml
+from celery.utils.abstract import CallableTask
 from colorama import Fore, Style
 from flask import g
 from flask.cli import FlaskGroup, with_appcontext
@@ -56,17 +57,17 @@ def normalize_token(token_name: str) -> str:
     context_settings={"token_normalize_func": normalize_token},
 )
 @with_appcontext
-def superset():
+def superset() -> None:
     """This is a management script for the Superset application."""
 
     @app.shell_context_processor
-    def make_shell_context():  # pylint: disable=unused-variable
+    def make_shell_context() -> Dict[str, Any]:  # pylint: disable=unused-variable
         return dict(app=app, db=db)
 
 
 @superset.command()
 @with_appcontext
-def init():
+def init() -> None:
     """Inits the Superset application"""
     appbuilder.add_permissions(update_perms=True)
     security_manager.sync_role_definitions()
@@ -75,7 +76,7 @@ def init():
 @superset.command()
 @with_appcontext
 @click.option("--verbose", "-v", is_flag=True, help="Show extra information")
-def version(verbose):
+def version(verbose: bool) -> None:
     """Prints the current version number"""
     print(Fore.BLUE + "-=" * 15)
     print(
@@ -90,7 +91,9 @@ def version(verbose):
     print(Style.RESET_ALL)
 
 
-def load_examples_run(load_test_data, only_metadata=False, force=False):
+def load_examples_run(
+    load_test_data: bool, only_metadata: bool = False, force: bool = False
+) -> None:
     if only_metadata:
         print("Loading examples metadata")
     else:
@@ -160,7 +163,9 @@ def load_examples_run(load_test_data, only_metadata=False, force=False):
 @click.option(
     "--force", "-f", is_flag=True, help="Force load data even if table already exists"
 )
-def load_examples(load_test_data, only_metadata=False, force=False):
+def load_examples(
+    load_test_data: bool, only_metadata: bool = False, force: bool = False
+) -> None:
     """Loads a set of Slices and Dashboards and a supporting dataset """
     load_examples_run(load_test_data, only_metadata, force)
 
@@ -169,7 +174,7 @@ def load_examples(load_test_data, only_metadata=False, force=False):
 @superset.command()
 @click.option("--database_name", "-d", help="Database name to change")
 @click.option("--uri", "-u", help="Database URI to change")
-def set_database_uri(database_name, uri):
+def set_database_uri(database_name: str, uri: str) -> None:
     """Updates a database connection URI """
     utils.get_or_create_db(database_name, uri)
 
@@ -189,7 +194,7 @@ def set_database_uri(database_name, uri):
     default=False,
     help="Specify using 'merge' property during operation. " "Default value is False.",
 )
-def refresh_druid(datasource, merge):
+def refresh_druid(datasource: str, merge: bool) -> None:
     """Refresh druid datasources"""
     session = db.session()
     from superset.connectors.druid.models import DruidCluster
@@ -226,7 +231,7 @@ def refresh_druid(datasource, merge):
     default=None,
     help="Specify the user name to assign dashboards to",
 )
-def import_dashboards(path, recursive, username):
+def import_dashboards(path: str, recursive: bool, username: str) -> None:
     """Import dashboards from JSON"""
     from superset.utils import dashboard_import_export
 
@@ -258,7 +263,7 @@ def import_dashboards(path, recursive, username):
 @click.option(
     "--print_stdout", "-p", is_flag=True, default=False, help="Print JSON to stdout"
 )
-def export_dashboards(print_stdout, dashboard_file):
+def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
     """Export dashboards to JSON"""
     from superset.utils import dashboard_import_export
 
@@ -295,7 +300,7 @@ def export_dashboards(print_stdout, dashboard_file):
     default=False,
     help="recursively search the path for yaml files",
 )
-def import_datasources(path, sync, recursive):
+def import_datasources(path: str, sync: str, recursive: bool) -> None:
     """Import datasources from YAML"""
     from superset.utils import dict_import_export
 
@@ -345,8 +350,11 @@ def import_datasources(path, sync, recursive):
     help="Include fields containing defaults",
 )
 def export_datasources(
-    print_stdout, datasource_file, back_references, include_defaults
-):
+    print_stdout: bool,
+    datasource_file: str,
+    back_references: bool,
+    include_defaults: bool,
+) -> None:
     """Export datasources to YAML"""
     from superset.utils import dict_import_export
 
@@ -373,7 +381,7 @@ def export_datasources(
     default=False,
     help="Include parent back references",
 )
-def export_datasource_schema(back_references):
+def export_datasource_schema(back_references: bool) -> None:
     """Export datasource YAML schema to stdout"""
     from superset.utils import dict_import_export
 
@@ -383,7 +391,7 @@ def export_datasource_schema(back_references):
 
 @superset.command()
 @with_appcontext
-def update_datasources_cache():
+def update_datasources_cache() -> None:
     """Refresh sqllab datasources cache"""
     from superset.models.core import Database
 
@@ -406,7 +414,7 @@ def update_datasources_cache():
 @click.option(
     "--workers", "-w", type=int, help="Number of celery server workers to fire up"
 )
-def worker(workers):
+def worker(workers: int) -> None:
     """Starts a Superset worker for async SQL query execution."""
     logger.info(
         "The 'superset worker' command is deprecated. Please use the 'celery "
@@ -431,7 +439,7 @@ def worker(workers):
 @click.option(
     "-a", "--address", default="localhost", help="Address on which to run the service"
 )
-def flower(port, address):
+def flower(port: int, address: str) -> None:
     """Runs a Celery Flower web server
 
     Celery Flower is a UI to monitor the Celery operation on a given
@@ -487,7 +495,7 @@ def compute_thumbnails(
     charts_only: bool,
     force: bool,
     model_id: int,
-):
+) -> None:
     """Compute thumbnails"""
     from superset.models.dashboard import Dashboard
     from superset.models.slice import Slice
@@ -500,8 +508,8 @@ def compute_thumbnails(
         friendly_type: str,
         model_cls: Union[Type[Dashboard], Type[Slice]],
         model_id: int,
-        compute_func,
-    ):
+        compute_func: CallableTask,
+    ) -> None:
         query = db.session.query(model_cls)
         if model_id:
             query = query.filter(model_cls.id.in_(model_id))
@@ -528,7 +536,7 @@ def compute_thumbnails(
 
 @superset.command()
 @with_appcontext
-def load_test_users():
+def load_test_users() -> None:
     """
     Loads admin, alpha, and gamma user for testing purposes
 
@@ -538,7 +546,7 @@ def load_test_users():
     load_test_users_run()
 
 
-def load_test_users_run():
+def load_test_users_run() -> None:
     """
     Loads admin, alpha, and gamma user for testing purposes
 
@@ -583,7 +591,7 @@ def load_test_users_run():
 
 @superset.command()
 @with_appcontext
-def sync_tags():
+def sync_tags() -> None:
     """Rebuilds special tags (owner, type, favorited by)."""
     # pylint: disable=no-member
     metadata = Model.metadata
diff --git a/superset/config.py b/superset/config.py
index 738c251..35dbbf8 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -28,8 +28,9 @@ import os
 import sys
 from collections import OrderedDict
 from datetime import date
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
 
+from cachelib.base import BaseCache
 from celery.schedules import crontab
 from dateutil import tz
 from flask_appbuilder.security.manager import AUTH_DB
@@ -78,7 +79,7 @@ PACKAGE_JSON_FILE = os.path.join(BASE_DIR, "static", "assets", "package.json")
 FAVICONS = [{"href": "/static/assets/images/favicon.png"}]
 
 
-def _try_json_readversion(filepath):
+def _try_json_readversion(filepath: str) -> Optional[str]:
     try:
         with open(filepath, "r") as f:
             return json.load(f).get("version")
@@ -86,7 +87,9 @@ def _try_json_readversion(filepath):
         return None
 
 
-def _try_json_readsha(filepath, length):  # pylint: disable=unused-argument
+def _try_json_readsha(  # pylint: disable=unused-argument
+    filepath: str, length: int
+) -> Optional[str]:
     try:
         with open(filepath, "r") as f:
             return json.load(f).get("GIT_SHA")[:length]
@@ -453,6 +456,7 @@ BACKUP_COUNT = 30
 #     user=None,
 #     client=None,
 #     security_manager=None,
+#     log_params=None,
 # ):
 #     pass
 QUERY_LOGGER = None
@@ -578,10 +582,9 @@ SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
     Callable[["Database", "models.User", str, str], str]
 ] = None
 
-# An instantiated derivative of cachelib.base.BaseCache
-# if enabled, it can be used to store the results of long-running queries
+# If enabled, it can be used to store the results of long-running queries
 # in SQL Lab by using the "Run Async" button/feature
-RESULTS_BACKEND = None
+RESULTS_BACKEND: Optional[BaseCache] = None
 
 # Use PyArrow and MessagePack for async query results serialization,
 # rather than JSON. This feature requires additional testing from the
@@ -604,7 +607,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC: Callable[
 
 # The namespace within hive where the tables created from
 # uploading CSVs will be stored.
-UPLOADED_CSV_HIVE_NAMESPACE = None
+UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None
 
 # Function that computes the allowed schemas for the CSV uploads.
 # Allowed schemas will be a union of schemas_allowed_for_csv_upload
@@ -614,7 +617,7 @@ UPLOADED_CSV_HIVE_NAMESPACE = None
 ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
     ["Database", "models.User"], List[str]
 ] = lambda database, user: [
-    UPLOADED_CSV_HIVE_NAMESPACE  # type: ignore
+    UPLOADED_CSV_HIVE_NAMESPACE
 ] if UPLOADED_CSV_HIVE_NAMESPACE else []
 
 # A dictionary of items that gets merged into the Jinja context for
@@ -628,7 +631,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {}
 # dictionary, which means the existing keys get overwritten by the content of this
 # dictionary. The customized addons don't necessarily need to use jinjia templating
 # language. This allows you to define custom logic to process macro template.
-CUSTOM_TEMPLATE_PROCESSORS = {}  # type: Dict[str, BaseTemplateProcessor]
+CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {}
 
 # Roles that are controlled by the API / Superset and should not be changes
 # by humans.
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 59ea042..51bd85f 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -32,7 +32,7 @@ class SupersetException(Exception):
         super().__init__(self.message)
 
     @property
-    def exception(self):
+    def exception(self) -> Optional[Exception]:
         return self._exception
 
 
diff --git a/superset/extensions.py b/superset/extensions.py
index c501eeb..f321046 100644
--- a/superset/extensions.py
+++ b/superset/extensions.py
@@ -20,10 +20,12 @@ import random
 import time
 import uuid
 from datetime import datetime, timedelta
-from typing import Dict, TYPE_CHECKING  # pylint: disable=unused-import
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
 
 import celery
+from cachelib.base import BaseCache
 from dateutil.relativedelta import relativedelta
+from flask import Flask
 from flask_appbuilder import AppBuilder, SQLA
 from flask_migrate import Migrate
 from flask_talisman import Talisman
@@ -32,7 +34,6 @@ from werkzeug.local import LocalProxy
 from superset.utils.cache_manager import CacheManager
 from superset.utils.feature_flag_manager import FeatureFlagManager
 
-# Avoid circular import
 if TYPE_CHECKING:
     from superset.jinja_context import (  # pylint: disable=unused-import
         BaseTemplateProcessor,
@@ -49,18 +50,18 @@ class JinjaContextManager:
             "timedelta": timedelta,
             "uuid": uuid,
         }
-        self._template_processors = {}  # type: Dict[str, BaseTemplateProcessor]
+        self._template_processors: Dict[str, Type["BaseTemplateProcessor"]] = {}
 
-    def init_app(self, app):
+    def init_app(self, app: Flask) -> None:
         self._base_context.update(app.config["JINJA_CONTEXT_ADDONS"])
         self._template_processors.update(app.config["CUSTOM_TEMPLATE_PROCESSORS"])
 
     @property
-    def base_context(self):
+    def base_context(self) -> Dict[str, Any]:
         return self._base_context
 
     @property
-    def template_processors(self):
+    def template_processors(self) -> Dict[str, Type["BaseTemplateProcessor"]]:
         return self._template_processors
 
 
@@ -69,35 +70,35 @@ class ResultsBackendManager:
         self._results_backend = None
         self._use_msgpack = False
 
-    def init_app(self, app):
-        self._results_backend = app.config.get("RESULTS_BACKEND")
-        self._use_msgpack = app.config.get("RESULTS_BACKEND_USE_MSGPACK")
+    def init_app(self, app: Flask) -> None:
+        self._results_backend = app.config["RESULTS_BACKEND"]
+        self._use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"]
 
     @property
-    def results_backend(self):
+    def results_backend(self) -> Optional[BaseCache]:
         return self._results_backend
 
     @property
-    def should_use_msgpack(self):
+    def should_use_msgpack(self) -> bool:
         return self._use_msgpack
 
 
 class UIManifestProcessor:
     def __init__(self, app_dir: str) -> None:
-        self.app = None
-        self.manifest: dict = {}
+        self.app: Optional[Flask] = None
+        self.manifest: Dict[str, Dict[str, List[str]]] = {}
         self.manifest_file = f"{app_dir}/static/assets/manifest.json"
 
-    def init_app(self, app):
+    def init_app(self, app: Flask) -> None:
         self.app = app
         # Preload the cache
         self.parse_manifest_json()
 
         @app.context_processor
-        def get_manifest():  # pylint: disable=unused-variable
+        def get_manifest() -> Dict[str, Callable]:  # pylint: disable=unused-variable
             loaded_chunks = set()
 
-            def get_files(bundle, asset_type="js"):
+            def get_files(bundle: str, asset_type: str = "js") -> List[str]:
                 files = self.get_manifest_files(bundle, asset_type)
                 filtered_files = [f for f in files if f not in loaded_chunks]
                 for f in filtered_files:
@@ -109,18 +110,18 @@ class UIManifestProcessor:
                 css_manifest=lambda bundle: get_files(bundle, "css"),
             )
 
-    def parse_manifest_json(self):
+    def parse_manifest_json(self) -> None:
         try:
             with open(self.manifest_file, "r") as f:
-                # the manifest includes non-entry files
-                # we only need entries in templates
+                # the manifest includes non-entry files we only need entries in
+                # templates
                 full_manifest = json.load(f)
                 self.manifest = full_manifest.get("entrypoints", {})
         except Exception:  # pylint: disable=broad-except
             pass
 
-    def get_manifest_files(self, bundle, asset_type):
-        if self.app.debug:
+    def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]:
+        if self.app and self.app.debug:
             self.parse_manifest_json()
         return self.manifest.get(bundle, {}).get(asset_type, [])
 
@@ -133,7 +134,7 @@ db = SQLA()
 _event_logger: dict = {}
 event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
 feature_flag_manager = FeatureFlagManager()
-jinja_context_manager = JinjaContextManager()  # type: JinjaContextManager
+jinja_context_manager = JinjaContextManager()
 manifest_processor = UIManifestProcessor(APP_DIR)
 migrate = Migrate()
 results_backend_manager = ResultsBackendManager()
diff --git a/superset/forms.py b/superset/forms.py
index 175903a..4ba3ca2 100644
--- a/superset/forms.py
+++ b/superset/forms.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Contains the logic to create cohesive forms on the explore view"""
-from typing import List  # pylint: disable=unused-import
+from typing import Any, List, Optional
 
 from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
 from wtforms import Field
@@ -25,24 +25,24 @@ class CommaSeparatedListField(Field):
     widget = BS3TextFieldWidget()
     data: List[str] = []
 
-    def _value(self):
+    def _value(self) -> str:
         if self.data:
-            return u", ".join(self.data)
+            return ", ".join(self.data)
 
-        return u""
+        return ""
 
-    def process_formdata(self, valuelist):
+    def process_formdata(self, valuelist: List[str]) -> None:
         if valuelist:
             self.data = [x.strip() for x in valuelist[0].split(",")]
         else:
             self.data = []
 
 
-def filter_not_empty_values(value):
+def filter_not_empty_values(values: Optional[List[Any]]) -> Optional[List[Any]]:
     """Returns a list of non empty values or None"""
-    if not value:
+    if not values:
         return None
-    data = [x for x in value if x]
+    data = [value for value in values if value]
     if not data:
         return None
     return data
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index e1a10cd..95ee723 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -17,7 +17,7 @@
 """Defines the templating context for SQL Lab"""
 import inspect
 import re
-from typing import Any, List, Optional, Tuple, TYPE_CHECKING
+from typing import Any, cast, List, Optional, Tuple, TYPE_CHECKING
 
 from flask import g, request
 from jinja2.sandbox import SandboxedEnvironment
@@ -207,7 +207,7 @@ class BaseTemplateProcessor:  # pylint: disable=too-few-public-methods
 
     def __init__(
         self,
-        database: Optional["Database"] = None,
+        database: "Database",
         query: Optional["Query"] = None,
         table: Optional["SqlaTable"] = None,
         extra_cache_keys: Optional[List[Any]] = None,
@@ -266,7 +266,7 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
             schema, table_name = table_name.split(".")
         return table_name, schema
 
-    def first_latest_partition(self, table_name: str) -> str:
+    def first_latest_partition(self, table_name: str) -> Optional[str]:
         """
         Gets the first value in the array of all latest partitions
 
@@ -275,9 +275,10 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
         :raises IndexError: If no partition exists
         """
 
-        return self.latest_partitions(table_name)[0]
+        latest_partitions = self.latest_partitions(table_name)
+        return latest_partitions[0] if latest_partitions else None
 
-    def latest_partitions(self, table_name: str) -> List[str]:
+    def latest_partitions(self, table_name: str) -> Optional[List[str]]:
         """
         Gets the array of all latest partitions
 
@@ -285,16 +286,21 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
         :return: the latest partition array
         """
 
+        from superset.db_engine_specs.presto import PrestoEngineSpec
+
         table_name, schema = self._schema_table(table_name, self.schema)
-        assert self.database
-        return self.database.db_engine_spec.latest_partition(  # type: ignore
+        return cast(PrestoEngineSpec, self.database.db_engine_spec).latest_partition(
             table_name, schema, self.database
         )[1]
 
-    def latest_sub_partition(self, table_name, **kwargs):
+    def latest_sub_partition(self, table_name: str, **kwargs: Any) -> Any:
         table_name, schema = self._schema_table(table_name, self.schema)
-        assert self.database
-        return self.database.db_engine_spec.latest_sub_partition(
+
+        from superset.db_engine_specs.presto import PrestoEngineSpec
+
+        return cast(
+            PrestoEngineSpec, self.database.db_engine_spec
+        ).latest_sub_partition(
             table_name=table_name, schema=schema, database=self.database, **kwargs
         )
 
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index ab952db..3ba1e3a 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -19,7 +19,7 @@ import uuid
 from contextlib import closing
 from datetime import datetime
 from sys import getsizeof
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
 
 import backoff
 import msgpack
@@ -27,9 +27,10 @@ import pyarrow as pa
 import simplejson as json
 import sqlalchemy
 from celery.exceptions import SoftTimeLimitExceeded
+from celery.task.base import Task
 from contextlib2 import contextmanager
 from flask_babel import lazy_gettext as _
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker
 from sqlalchemy.pool import NullPool
 
 from superset import (
@@ -77,7 +78,9 @@ class SqlLabTimeoutException(SqlLabException):
     pass
 
 
-def handle_query_error(msg, query, session, payload=None):
+def handle_query_error(
+    msg: str, query: Query, session: Session, payload: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
     """Local method handling error while processing the SQL"""
     payload = payload or {}
     troubleshooting_link = config["TROUBLESHOOTING_LINK"]
@@ -91,14 +94,14 @@ def handle_query_error(msg, query, session, payload=None):
     return payload
 
 
-def get_query_backoff_handler(details):
+def get_query_backoff_handler(details: Dict[Any, Any]) -> None:
     query_id = details["kwargs"]["query_id"]
     logger.error(f"Query with id `{query_id}` could not be retrieved")
     stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1))
     logger.error(f"Query {query_id}: Sleeping for a sec before retrying...")
 
 
-def get_query_giveup_handler(_):
+def get_query_giveup_handler(_: Any) -> None:
     stats_logger.incr("error_failed_at_getting_orm_query")
 
 
@@ -110,7 +113,7 @@ def get_query_giveup_handler(_):
     on_giveup=get_query_giveup_handler,
     max_tries=5,
 )
-def get_query(query_id, session):
+def get_query(query_id: int, session: Session) -> Query:
     """attempts to get the query and retry if it cannot"""
     try:
         return session.query(Query).filter_by(id=query_id).one()
@@ -119,7 +122,7 @@ def get_query(query_id, session):
 
 
 @contextmanager
-def session_scope(nullpool):
+def session_scope(nullpool: bool) -> Iterator[Session]:
     """Provide a transactional scope around a series of operations."""
     database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
     if "sqlite" in database_uri:
@@ -154,16 +157,16 @@ def session_scope(nullpool):
     soft_time_limit=SQLLAB_TIMEOUT,
 )
 def get_sql_results(  # pylint: disable=too-many-arguments
-    ctask,
-    query_id,
-    rendered_query,
-    return_results=True,
-    store_results=False,
-    user_name=None,
-    start_time=None,
-    expand_data=False,
-    log_params=None,
-):
+    ctask: Task,
+    query_id: int,
+    rendered_query: str,
+    return_results: bool = True,
+    store_results: bool = False,
+    user_name: Optional[str] = None,
+    start_time: Optional[float] = None,
+    expand_data: bool = False,
+    log_params: Optional[Dict[str, Any]] = None,
+) -> Optional[Dict[str, Any]]:
     """Executes the sql query returns the results."""
     with session_scope(not ctask.request.called_directly) as session:
 
@@ -188,7 +191,14 @@ def get_sql_results(  # pylint: disable=too-many-arguments
 
 
 # pylint: disable=too-many-arguments
-def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_params):
+def execute_sql_statement(
+    sql_statement: str,
+    query: Query,
+    user_name: Optional[str],
+    session: Session,
+    cursor: Any,
+    log_params: Optional[Dict[str, Any]],
+) -> SupersetResultSet:
     """Executes a single SQL statement"""
     database = query.database
     db_engine_spec = database.db_engine_spec
@@ -275,7 +285,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_
 
 
 def _serialize_payload(
-    payload: dict, use_msgpack: Optional[bool] = False
+    payload: Dict[Any, Any], use_msgpack: Optional[bool] = False
 ) -> Union[bytes, str]:
     logger.debug(f"Serializing to msgpack: {use_msgpack}")
     if use_msgpack:
@@ -321,24 +331,24 @@ def _serialize_and_expand_data(
     return (data, selected_columns, all_columns, expanded_columns)
 
 
-def execute_sql_statements(
-    query_id,
-    rendered_query,
-    return_results=True,
-    store_results=False,
-    user_name=None,
-    session=None,
-    start_time=None,
-    expand_data=False,
-    log_params=None,
-):  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
+def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
+    query_id: int,
+    rendered_query: str,
+    return_results: bool,
+    store_results: bool,
+    user_name: Optional[str],
+    session: Session,
+    start_time: Optional[float],
+    expand_data: bool,
+    log_params: Optional[Dict[str, Any]],
+) -> Optional[Dict[str, Any]]:
     """Executes the sql query returns the results."""
     if store_results and start_time:
         # only asynchronous queries
         stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)
 
     query = get_query(query_id, session)
-    payload = dict(query_id=query_id)
+    payload: Dict[str, Any] = dict(query_id=query_id)
     database = query.database
     db_engine_spec = database.db_engine_spec
     db_engine_spec.patch()
@@ -406,7 +416,7 @@ def execute_sql_statements(
         )
     query.end_time = now_as_float()
 
-    use_arrow_data = store_results and results_backend_use_msgpack
+    use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
     data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
         result_set, db_engine_spec, use_arrow_data, expand_data
     )
@@ -432,7 +442,7 @@ def execute_sql_statements(
                 "sqllab.query.results_backend_write_serialization", stats_logger
             ):
                 serialized_payload = _serialize_payload(
-                    payload, results_backend_use_msgpack
+                    payload, cast(bool, results_backend_use_msgpack)
                 )
             cache_timeout = database.cache_timeout
             if cache_timeout is None:
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index be9cf10..3e50386 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -158,7 +158,7 @@ class ParsedQuery:
     def _is_identifier(token: Token) -> bool:
         return isinstance(token, (IdentifierList, Identifier))
 
-    def _process_tokenlist(self, token_list: TokenList):
+    def _process_tokenlist(self, token_list: TokenList) -> None:
         """
         Add table names to table set
 
@@ -204,7 +204,9 @@ class ParsedQuery:
         exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
         return exec_sql
 
-    def _extract_from_token(self, token: Token):  # pylint: disable=too-many-branches
+    def _extract_from_token(  # pylint: disable=too-many-branches
+        self, token: Token
+    ) -> None:
         """
         Populate self._tables from token
 
diff --git a/superset/stats_logger.py b/superset/stats_logger.py
index 37fe3d3..75cfd8a 100644
--- a/superset/stats_logger.py
+++ b/superset/stats_logger.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+from typing import Optional
 
 from colorama import Fore, Style
 
@@ -40,7 +41,7 @@ class BaseStatsLogger:
         """Decrement a counter"""
         raise NotImplementedError()
 
-    def timing(self, key, value: float) -> None:
+    def timing(self, key: str, value: float) -> None:
         raise NotImplementedError()
 
     def gauge(self, key: str) -> None:
@@ -49,18 +50,18 @@ class BaseStatsLogger:
 
 
 class DummyStatsLogger(BaseStatsLogger):
-    def incr(self, key):
+    def incr(self, key: str) -> None:
         logger.debug(Fore.CYAN + "[stats_logger] (incr) " + key + Style.RESET_ALL)
 
-    def decr(self, key):
+    def decr(self, key: str) -> None:
         logger.debug((Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL))
 
-    def timing(self, key, value):
+    def timing(self, key: str, value: float) -> None:
         logger.debug(
             (Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL)
         )
 
-    def gauge(self, key):
+    def gauge(self, key: str) -> None:
         logger.debug(
             (Fore.CYAN + "[stats_logger] (gauge) " + f"{key}" + Style.RESET_ALL)
         )
@@ -71,8 +72,12 @@ try:
 
     class StatsdStatsLogger(BaseStatsLogger):
         def __init__(  # pylint: disable=super-init-not-called
-            self, host="localhost", port=8125, prefix="superset", statsd_client=None
-        ):
+            self,
+            host: str = "localhost",
+            port: int = 8125,
+            prefix: str = "superset",
+            statsd_client: Optional[StatsClient] = None,
+        ) -> None:
             """
             Initializes from either params or a supplied, pre-constructed statsd client.
 
@@ -84,16 +89,16 @@ try:
             else:
                 self.client = StatsClient(host=host, port=port, prefix=prefix)
 
-        def incr(self, key):
+        def incr(self, key: str) -> None:
             self.client.incr(key)
 
-        def decr(self, key):
+        def decr(self, key: str) -> None:
             self.client.decr(key)
 
-        def timing(self, key, value):
+        def timing(self, key: str, value: float) -> None:
             self.client.timing(key, value)
 
-        def gauge(self, key):
+        def gauge(self, key: str) -> None:
             # pylint: disable=no-value-for-parameter
             self.client.gauge(key)
 
diff --git a/superset/typing.py b/superset/typing.py
index 09a3393..e238000 100644
--- a/superset/typing.py
+++ b/superset/typing.py
@@ -33,6 +33,7 @@ Granularity = Union[str, Dict[str, Union[str, float]]]
 Metric = Union[Dict[str, str], str]
 QueryObjectDict = Dict[str, Any]
 VizData = Optional[Union[List[Any], Dict[Any, Any]]]
+VizPayload = Dict[str, Any]
 
 # Flask response.
 Base = Union[bytes, str]
diff --git a/superset/viz.py b/superset/viz.py
index bf3c110..2e38bf2 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -31,7 +31,7 @@ import uuid
 from collections import defaultdict, OrderedDict
 from datetime import datetime, timedelta
 from itertools import product
-from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
+from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
 
 import dataclasses
 import geohash
@@ -55,7 +55,7 @@ from superset.exceptions import (
     SpatialException,
 )
 from superset.models.helpers import QueryResult
-from superset.typing import VizData
+from superset.typing import QueryObjectDict, VizData, VizPayload
 from superset.utils import core as utils
 from superset.utils.core import (
     DTTM_ALIAS,
@@ -101,7 +101,7 @@ class BaseViz:
         datasource: "BaseDatasource",
         form_data: Dict[str, Any],
         force: bool = False,
-    ):
+    ) -> None:
         if not datasource:
             raise Exception(_("Viz is missing a datasource"))
 
@@ -134,7 +134,7 @@ class BaseViz:
 
         self.process_metrics()
 
-    def process_metrics(self):
+    def process_metrics(self) -> None:
         # metrics in TableViz is order sensitive, so metric_dict should be
         # OrderedDict
         self.metric_dict = OrderedDict()
@@ -153,8 +153,10 @@ class BaseViz:
         self.metric_labels = list(self.metric_dict.keys())
 
     @staticmethod
-    def handle_js_int_overflow(data):
-        for d in data.get("records", dict()):
+    def handle_js_int_overflow(
+        data: Dict[str, List[Dict[str, Any]]]
+    ) -> Dict[str, List[Dict[str, Any]]]:
+        for d in data.get("records", {}):
             for k, v in list(d.items()):
                 if isinstance(v, int):
                     # if an int is too big for Java Script to handle
@@ -163,7 +165,7 @@ class BaseViz:
                         d[k] = str(v)
         return data
 
-    def run_extra_queries(self):
+    def run_extra_queries(self) -> None:
         """Lifecycle method to use when more than one query is needed
 
         In rare-ish cases, a visualization may need to execute multiple
@@ -186,7 +188,7 @@ class BaseViz:
         """
         pass
 
-    def apply_rolling(self, df):
+    def apply_rolling(self, df: pd.DataFrame) -> pd.DataFrame:
         fd = self.form_data
         rolling_type = fd.get("rolling_type")
         rolling_periods = int(fd.get("rolling_periods") or 0)
@@ -206,7 +208,7 @@ class BaseViz:
             df = df[min_periods:]
         return df
 
-    def get_samples(self):
+    def get_samples(self) -> List[Dict[str, Any]]:
         query_obj = self.query_obj()
         query_obj.update(
             {
@@ -219,7 +221,7 @@ class BaseViz:
         df = self.get_df(query_obj)
         return df.to_dict(orient="records")
 
-    def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+    def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
         """Returns a pandas dataframe based on the query object"""
         if not query_obj:
             query_obj = self.query_obj()
@@ -281,19 +283,19 @@ class BaseViz:
             df.replace([np.inf, -np.inf], np.nan, inplace=True)
         return df
 
-    def df_metrics_to_num(self, df):
+    def df_metrics_to_num(self, df: pd.DataFrame) -> None:
         """Converting metrics to numeric when pandas.read_sql cannot"""
         metrics = self.metric_labels
         for col, dtype in df.dtypes.items():
             if dtype.type == np.object_ and col in metrics:
                 df[col] = pd.to_numeric(df[col], errors="coerce")
 
-    def process_query_filters(self):
+    def process_query_filters(self) -> None:
         utils.convert_legacy_filters_into_adhoc(self.form_data)
         merge_extra_filters(self.form_data)
         utils.split_adhoc_filters_into_base_filters(self.form_data)
 
-    def query_obj(self) -> Dict[str, Any]:
+    def query_obj(self) -> QueryObjectDict:
         """Building a query object"""
         form_data = self.form_data
         self.process_query_filters()
@@ -362,9 +364,9 @@ class BaseViz:
         return d
 
     @property
-    def cache_timeout(self):
+    def cache_timeout(self) -> int:
         if self.form_data.get("cache_timeout") is not None:
-            return int(self.form_data.get("cache_timeout"))
+            return int(self.form_data["cache_timeout"])
         if self.datasource.cache_timeout is not None:
             return self.datasource.cache_timeout
         if (
@@ -374,12 +376,12 @@ class BaseViz:
             return self.datasource.database.cache_timeout
         return config["CACHE_DEFAULT_TIMEOUT"]
 
-    def get_json(self):
+    def get_json(self) -> str:
         return json.dumps(
             self.get_payload(), default=utils.json_int_dttm_ser, ignore_nan=True
         )
 
-    def cache_key(self, query_obj, **extra):
+    def cache_key(self, query_obj: QueryObjectDict, **extra: Any) -> str:
         """
         The cache key is made out of the key/values in `query_obj`, plus any
         other key/values in `extra`.
@@ -410,7 +412,7 @@ class BaseViz:
         json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode("utf-8")).hexdigest()
 
-    def get_payload(self, query_obj=None):
+    def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload:
         """Returns a payload of metadata and data"""
         self.run_extra_queries()
         payload = self.get_df_payload(query_obj)
@@ -422,7 +424,9 @@ class BaseViz:
             del payload["df"]
         return payload
 
-    def get_df_payload(self, query_obj=None, **kwargs):
+    def get_df_payload(
+        self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any
+    ) -> Dict[str, Any]:
         """Handles caching around the df payload retrieval"""
         if not query_obj:
             query_obj = self.query_obj()
@@ -512,21 +516,21 @@ class BaseViz:
             "rowcount": len(df.index) if df is not None else 0,
         }
 
-    def json_dumps(self, obj, sort_keys=False):
+    def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
         return json.dumps(
             obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
         )
 
-    def payload_json_and_has_error(self, payload):
+    def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
         has_error = (
             payload.get("status") == utils.QueryStatus.FAILED
             or payload.get("error") is not None
-            or len(payload.get("errors")) > 0
+            or len(payload.get("errors") or []) > 0
         )
         return self.json_dumps(payload), has_error
 
     @property
-    def data(self):
+    def data(self) -> Dict[str, Any]:
         """This is the data object serialized to the js layer"""
         content = {
             "form_data": self.form_data,
@@ -536,7 +540,7 @@ class BaseViz:
         }
         return content
 
-    def get_csv(self):
+    def get_csv(self) -> Optional[str]:
         df = self.get_df()
         include_index = not isinstance(df.index, pd.RangeIndex)
         return df.to_csv(index=include_index, **config["CSV_EXPORT"])
@@ -545,7 +549,7 @@ class BaseViz:
         return df.to_dict(orient="records")
 
     @property
-    def json_data(self):
+    def json_data(self) -> str:
         return json.dumps(self.data)
 
 
@@ -559,7 +563,7 @@ class TableViz(BaseViz):
     is_timeseries = False
     enforce_numerical_metrics = False
 
-    def should_be_timeseries(self):
+    def should_be_timeseries(self) -> bool:
         fd = self.form_data
         # TODO handle datasource-type-specific code in datasource
         conditions_met = (fd.get("granularity") and fd.get("granularity") != "all") or (
@@ -569,9 +573,9 @@ class TableViz(BaseViz):
             raise QueryObjectValidationError(
                 _("Pick a granularity in the Time section or " "uncheck 'Include Time'")
             )
-        return fd.get("include_time")
+        return bool(fd.get("include_time"))
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         fd = self.form_data
 
@@ -660,7 +664,7 @@ class TableViz(BaseViz):
 
         return data
 
-    def json_dumps(self, obj, sort_keys=False):
+    def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
         return json.dumps(
             obj, default=utils.json_iso_dttm_ser, sort_keys=sort_keys, ignore_nan=True
         )
@@ -675,14 +679,14 @@ class TimeTableViz(BaseViz):
     credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         fd = self.form_data
 
         if not fd.get("metrics"):
             raise QueryObjectValidationError(_("Pick at least one metric"))
 
-        if fd.get("groupby") and len(fd.get("metrics")) > 1:
+        if fd.get("groupby") and len(fd["metrics"]) > 1:
             raise QueryObjectValidationError(
                 _("When using 'Group By' you are limited to use a single metric")
             )
@@ -694,7 +698,7 @@ class TimeTableViz(BaseViz):
 
         fd = self.form_data
         columns = None
-        values = self.metric_labels
+        values: Union[List[str], str] = self.metric_labels
         if fd.get("groupby"):
             values = self.metric_labels[0]
             columns = fd.get("groupby")
@@ -717,7 +721,7 @@ class PivotTableViz(BaseViz):
     credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         groupby = self.form_data.get("groupby")
         columns = self.form_data.get("columns")
@@ -798,10 +802,10 @@ class MarkupViz(BaseViz):
     verbose_name = _("Markup")
     is_timeseries = False
 
-    def query_obj(self):
-        return None
+    def query_obj(self) -> QueryObjectDict:
+        return {}
 
-    def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+    def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
         return pd.DataFrame()
 
     def get_data(self, df: pd.DataFrame) -> VizData:
@@ -832,7 +836,7 @@ class WordCloudViz(BaseViz):
     verbose_name = _("Word Cloud")
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         d["groupby"] = [self.form_data.get("series")]
         return d
@@ -847,7 +851,7 @@ class TreemapViz(BaseViz):
     credits = '<a href="https://d3js.org">d3.js</a>'
     is_timeseries = False
 
-    def _nest(self, metric, df):
+    def _nest(self, metric: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
         nlevels = df.index.nlevels
         if nlevels == 1:
             result = [{"name": n, "value": v} for n, v in zip(df.index, df[metric])]
@@ -927,7 +931,7 @@ class CalHeatmapViz(BaseViz):
             "range": range_,
         }
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         fd = self.form_data
         d["metrics"] = fd.get("metrics")
@@ -953,19 +957,21 @@ class BoxPlotViz(NVD3Viz):
     sort_series = False
     is_timeseries = True
 
-    def to_series(self, df, classed="", title_suffix=""):
+    def to_series(
+        self, df: pd.DataFrame, classed: str = "", title_suffix: str = ""
+    ) -> List[Dict[str, Any]]:
         label_sep = " - "
         chart_data = []
         for index_value, row in zip(df.index, df.to_dict(orient="records")):
             if isinstance(index_value, tuple):
                 index_value = label_sep.join(index_value)
-            boxes = defaultdict(dict)
+            boxes: Dict[str, Dict[str, Any]] = defaultdict(dict)
             for (label, key), value in row.items():
                 if key == "nanmedian":
                     key = "Q2"
                 boxes[label][key] = value
             for label, box in boxes.items():
-                if len(self.form_data.get("metrics")) > 1:
+                if len(self.form_data["metrics"]) > 1:
                     # need to render data labels with metrics
                     chart_label = label_sep.join([index_value, label])
                 else:
@@ -980,46 +986,45 @@ class BoxPlotViz(NVD3Viz):
         form_data = self.form_data
 
         # conform to NVD3 names
-        def Q1(series):  # need to be named functions - can't use lambdas
+        def Q1(series: pd.Series) -> float:
+            # need to be named functions - can't use lambdas
             return np.nanpercentile(series, 25)
 
-        def Q3(series):
+        def Q3(series: pd.Series) -> float:
             return np.nanpercentile(series, 75)
 
         whisker_type = form_data.get("whisker_options")
         if whisker_type == "Tukey":
 
-            def whisker_high(series):
+            def whisker_high(series: pd.Series) -> float:
                 upper_outer_lim = Q3(series) + 1.5 * (Q3(series) - Q1(series))
                 return series[series <= upper_outer_lim].max()
 
-            def whisker_low(series):
+            def whisker_low(series: pd.Series) -> float:
                 lower_outer_lim = Q1(series) - 1.5 * (Q3(series) - Q1(series))
                 return series[series >= lower_outer_lim].min()
 
         elif whisker_type == "Min/max (no outliers)":
 
-            def whisker_high(series):
+            def whisker_high(series: pd.Series) -> float:
                 return series.max()
 
-            def whisker_low(series):
+            def whisker_low(series: pd.Series) -> float:
                 return series.min()
 
         elif " percentiles" in whisker_type:  # type: ignore
-            low, high = whisker_type.replace(" percentiles", "").split(  # type: ignore
-                "/"
-            )
+            low, high = cast(str, whisker_type).replace(" percentiles", "").split("/")
 
-            def whisker_high(series):
+            def whisker_high(series: pd.Series) -> float:
                 return np.nanpercentile(series, int(high))
 
-            def whisker_low(series):
+            def whisker_low(series: pd.Series) -> float:
                 return np.nanpercentile(series, int(low))
 
         else:
             raise ValueError("Unknown whisker type: {}".format(whisker_type))
 
-        def outliers(series):
+        def outliers(series: pd.Series) -> Set[float]:
             above = series[series > whisker_high(series)]
             below = series[series < whisker_low(series)]
             # pandas sometimes doesn't like getting lists back here
@@ -1039,7 +1044,7 @@ class BubbleViz(NVD3Viz):
     verbose_name = _("Bubble Chart")
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         form_data = self.form_data
         d = super().query_obj()
         d["groupby"] = [form_data.get("entity")]
@@ -1090,7 +1095,7 @@ class BulletViz(NVD3Viz):
     verbose_name = _("Bullet Chart")
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         form_data = self.form_data
         d = super().query_obj()
         self.metric = form_data["metric"]
@@ -1117,7 +1122,7 @@ class BigNumberViz(BaseViz):
     credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         metric = self.form_data.get("metric")
         if not metric:
@@ -1151,7 +1156,7 @@ class BigNumberTotalViz(BaseViz):
     credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         metric = self.form_data.get("metric")
         if not metric:
@@ -1174,7 +1179,9 @@ class NVD3TimeSeriesViz(NVD3Viz):
     is_timeseries = True
     pivot_fill_value: Optional[int] = None
 
-    def to_series(self, df, classed="", title_suffix=""):
+    def to_series(
+        self, df: pd.DataFrame, classed: str = "", title_suffix: str = ""
+    ) -> List[Dict[str, Any]]:
         cols = []
         for col in df.columns:
             if col == "":
@@ -1191,6 +1198,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
             ys = series[name]
             if df[name].dtype.kind not in "biufc":
                 continue
+            series_title: Union[List[str], str, Tuple[str, ...]]
             if isinstance(name, list):
                 series_title = [str(title) for title in name]
             elif isinstance(name, tuple):
@@ -1207,7 +1215,9 @@ class NVD3TimeSeriesViz(NVD3Viz):
             if title_suffix:
                 if isinstance(series_title, str):
                     series_title = (series_title, title_suffix)
-                elif isinstance(series_title, (list, tuple)):
+                elif isinstance(series_title, list):
+                    series_title = series_title + [title_suffix]
+                elif isinstance(series_title, tuple):
                     series_title = series_title + (title_suffix,)
 
             values = []
@@ -1274,7 +1284,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
 
         return df
 
-    def run_extra_queries(self):
+    def run_extra_queries(self) -> None:
         fd = self.form_data
 
         time_compare = fd.get("time_compare") or []
@@ -1364,8 +1374,8 @@ class MultiLineViz(NVD3Viz):
 
     is_timeseries = True
 
-    def query_obj(self):
-        return None
+    def query_obj(self) -> QueryObjectDict:
+        return {}
 
     def get_data(self, df: pd.DataFrame) -> VizData:
         fd = self.form_data
@@ -1394,7 +1404,7 @@ class NVD3DualLineViz(NVD3Viz):
     sort_series = False
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         m1 = self.form_data.get("metric")
         m2 = self.form_data.get("metric_2")
@@ -1409,7 +1419,7 @@ class NVD3DualLineViz(NVD3Viz):
             )
         return d
 
-    def to_series(self, df, classed=""):
+    def to_series(self, df: pd.DataFrame, classed: str = "") -> List[Dict[str, Any]]:
         cols = []
         for col in df.columns:
             if col == "":
@@ -1421,7 +1431,7 @@ class NVD3DualLineViz(NVD3Viz):
         df.columns = cols
         series = df.to_dict("series")
         chart_data = []
-        metrics = [self.form_data.get("metric"), self.form_data.get("metric_2")]
+        metrics = [self.form_data["metric"], self.form_data["metric_2"]]
         for i, m in enumerate(metrics):
             m = utils.get_metric_name(m)
             ys = series[m]
@@ -1476,7 +1486,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
     sort_series = True
     verbose_name = _("Time Series - Period Pivot")
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         d["metrics"] = [self.form_data.get("metric")]
         return d
@@ -1561,7 +1571,7 @@ class HistogramViz(BaseViz):
     verbose_name = _("Histogram")
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         """Returns the query object for this visualization"""
         d = super().query_obj()
         d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"]))
@@ -1576,9 +1586,9 @@ class HistogramViz(BaseViz):
         d["groupby"] = []
         return d
 
-    def labelify(self, keys, column):
+    def labelify(self, keys: Union[List[str], str], column: str) -> str:
         if isinstance(keys, str):
-            keys = (keys,)
+            keys = [keys]
         # removing undesirable characters
         labels = [re.sub(r"\W+", r"_", k) for k in keys]
         if len(self.columns) > 1 or not self.groupby:
@@ -1617,7 +1627,7 @@ class DistributionBarViz(DistributionPieViz):
     verbose_name = _("Distribution - Bar Chart")
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         fd = self.form_data
         if len(d["groupby"]) < len(fd.get("groupby") or []) + len(
@@ -1708,7 +1718,7 @@ class SunburstViz(BaseViz):
         df = df[cols]
         return df.to_numpy().tolist()
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         qry = super().query_obj()
         fd = self.form_data
         qry["metrics"] = [fd["metric"]]
@@ -1727,7 +1737,7 @@ class SankeyViz(BaseViz):
     is_timeseries = False
     credits = '<a href="https://www.npmjs.com/package/d3-sankey">d3-sankey on npm</a>'
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         qry = super().query_obj()
         if len(qry["groupby"]) != 2:
             raise QueryObjectValidationError(
@@ -1746,21 +1756,23 @@ class SankeyViz(BaseViz):
         for row in recs:
             hierarchy[row["source"]].add(row["target"])
 
-        def find_cycle(g):
+        def find_cycle(g: Dict[str, Set[str]]) -> Optional[Tuple[str, str]]:
             """Whether there's a cycle in a directed graph"""
             path = set()
 
-            def visit(vertex):
+            def visit(vertex: str) -> Optional[Tuple[str, str]]:
                 path.add(vertex)
                 for neighbour in g.get(vertex, ()):
                     if neighbour in path or visit(neighbour):
                         return (vertex, neighbour)
                 path.remove(vertex)
+                return None
 
             for v in g:
                 cycle = visit(v)
                 if cycle:
                     return cycle
+            return None
 
         cycle = find_cycle(hierarchy)
         if cycle:
@@ -1782,7 +1794,7 @@ class DirectedForceViz(BaseViz):
     credits = 'd3noob @<a href="http://bl.ocks.org/d3noob/5141278">bl.ocks.org</a>'
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         qry = super().query_obj()
         if len(self.form_data["groupby"]) != 2:
             raise QueryObjectValidationError(_("Pick exactly 2 columns to 'Group By'"))
@@ -1803,7 +1815,7 @@ class ChordViz(BaseViz):
     credits = '<a href="https://github.com/d3/d3-chord">Bostock</a>'
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         qry = super().query_obj()
         fd = self.form_data
         qry["groupby"] = [fd.get("groupby"), fd.get("columns")]
@@ -1836,7 +1848,7 @@ class CountryMapViz(BaseViz):
     is_timeseries = False
     credits = "From bl.ocks.org By john-guerra"
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         qry = super().query_obj()
         qry["metrics"] = [self.form_data["metric"]]
         qry["groupby"] = [self.form_data["entity"]]
@@ -1863,7 +1875,7 @@ class WorldMapViz(BaseViz):
     is_timeseries = False
     credits = 'datamaps on <a href="https://www.npmjs.com/package/datamaps">npm</a>'
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         qry = super().query_obj()
         qry["groupby"] = [self.form_data["entity"]]
         return qry
@@ -1923,10 +1935,10 @@ class FilterBoxViz(BaseViz):
     cache_type = "get_data"
     filter_row_limit = 1000
 
-    def query_obj(self):
-        return None
+    def query_obj(self) -> QueryObjectDict:
+        return {}
 
-    def run_extra_queries(self):
+    def run_extra_queries(self) -> None:
         qry = super().query_obj()
         filters = self.form_data.get("filter_configs") or []
         qry["row_limit"] = self.filter_row_limit
@@ -1979,10 +1991,10 @@ class IFrameViz(BaseViz):
     credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
     is_timeseries = False
 
-    def query_obj(self):
-        return None
+    def query_obj(self) -> QueryObjectDict:
+        return {}
 
-    def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
+    def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
         return pd.DataFrame()
 
     def get_data(self, df: pd.DataFrame) -> VizData:
@@ -2005,7 +2017,7 @@ class ParallelCoordinatesViz(BaseViz):
     )
     is_timeseries = False
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         fd = self.form_data
         d["groupby"] = [fd.get("series")]
@@ -2027,7 +2039,7 @@ class HeatmapViz(BaseViz):
         "bl.ocks.org</a>"
     )
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         fd = self.form_data
         d["metrics"] = [fd.get("metric")]
@@ -2092,7 +2104,7 @@ class MapboxViz(BaseViz):
     is_timeseries = False
     credits = "<a href=https://www.mapbox.com/mapbox-gl-js/api/>Mapbox GL JS</a>"
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         fd = self.form_data
         label_col = fd.get("mapbox_label")
@@ -2124,22 +2136,24 @@ class MapboxViz(BaseViz):
                 label_col
                 and len(label_col) >= 1
                 and label_col[0] != "count"
-                and label_col[0] not in fd.get("groupby")
+                and label_col[0] not in fd["groupby"]
             ):
                 raise QueryObjectValidationError(
                     _("Choice of [Label] must be present in [Group By]")
                 )
 
-            if fd.get("point_radius") != "Auto" and fd.get(
-                "point_radius"
-            ) not in fd.get("groupby"):
+            if (
+                fd.get("point_radius") != "Auto"
+                and fd.get("point_radius") not in fd["groupby"]
+            ):
                 raise QueryObjectValidationError(
                     _("Choice of [Point Radius] must be present in [Group By]")
                 )
 
-            if fd.get("all_columns_x") not in fd.get("groupby") or fd.get(
-                "all_columns_y"
-            ) not in fd.get("groupby"):
+            if (
+                fd.get("all_columns_x") not in fd["groupby"]
+                or fd.get("all_columns_y") not in fd["groupby"]
+            ):
                 raise QueryObjectValidationError(
                     _(
                         "[Longitude] and [Latitude] columns must be present in "
@@ -2226,8 +2240,8 @@ class DeckGLMultiLayer(BaseViz):
     is_timeseries = False
     credits = '<a href="https://uber.github.io/deck.gl/">deck.gl</a>'
 
-    def query_obj(self):
-        return None
+    def query_obj(self) -> QueryObjectDict:
+        return {}
 
     def get_data(self, df: pd.DataFrame) -> VizData:
         fd = self.form_data
@@ -2251,14 +2265,14 @@ class BaseDeckGLViz(BaseViz):
     credits = '<a href="https://uber.github.io/deck.gl/">deck.gl</a>'
     spatial_control_keys: List[str] = []
 
-    def get_metrics(self):
+    def get_metrics(self) -> List[str]:
         self.metric = self.form_data.get("size")
         return [self.metric] if self.metric else []
 
-    def process_spatial_query_obj(self, key, group_by):
+    def process_spatial_query_obj(self, key: str, group_by: List[str]) -> None:
         group_by.extend(self.get_spatial_columns(key))
 
-    def get_spatial_columns(self, key):
+    def get_spatial_columns(self, key: str) -> List[str]:
         spatial = self.form_data.get(key)
         if spatial is None:
             raise ValueError(_("Bad spatial key"))
@@ -2269,9 +2283,10 @@ class BaseDeckGLViz(BaseViz):
             return [spatial.get("lonlatCol")]
         elif spatial.get("type") == "geohash":
             return [spatial.get("geohashCol")]
+        return []
 
     @staticmethod
-    def parse_coordinates(s):
+    def parse_coordinates(s: Any) -> Optional[Tuple[float, float]]:
         if not s:
             return None
         try:
@@ -2281,15 +2296,15 @@ class BaseDeckGLViz(BaseViz):
             raise SpatialException(_("Invalid spatial point encountered: %s" % s))
 
     @staticmethod
-    def reverse_geohash_decode(geohash_code):
+    def reverse_geohash_decode(geohash_code: str) -> Tuple[str, str]:
         lat, lng = geohash.decode(geohash_code)
         return (lng, lat)
 
     @staticmethod
-    def reverse_latlong(df, key):
+    def reverse_latlong(df: pd.DataFrame, key: str) -> None:
         df[key] = [tuple(reversed(o)) for o in df[key] if isinstance(o, (list, tuple))]
 
-    def process_spatial_data_obj(self, key, df):
+    def process_spatial_data_obj(self, key: str, df: pd.DataFrame) -> pd.DataFrame:
         spatial = self.form_data.get(key)
         if spatial is None:
             raise ValueError(_("Bad spatial key"))
@@ -2321,7 +2336,7 @@ class BaseDeckGLViz(BaseViz):
             )
         return df
 
-    def add_null_filters(self):
+    def add_null_filters(self) -> None:
         fd = self.form_data
         spatial_columns = set()
         for key in self.spatial_control_keys:
@@ -2339,7 +2354,7 @@ class BaseDeckGLViz(BaseViz):
             filter_ = to_adhoc({"col": column, "op": "IS NOT NULL", "val": ""})
             fd["adhoc_filters"].append(filter_)
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         fd = self.form_data
 
         # add NULL filters
@@ -2347,16 +2362,16 @@ class BaseDeckGLViz(BaseViz):
             self.add_null_filters()
 
         d = super().query_obj()
-        gb = []
+        gb: List[str] = []
 
         for key in self.spatial_control_keys:
             self.process_spatial_query_obj(key, gb)
 
         if fd.get("dimension"):
-            gb += [fd.get("dimension")]
+            gb += [fd["dimension"]]
 
         if fd.get("js_columns"):
-            gb += fd.get("js_columns")
+            gb += fd.get("js_columns") or []
         metrics = self.get_metrics()
         gb = list(set(gb))
         if metrics:
@@ -2367,7 +2382,7 @@ class BaseDeckGLViz(BaseViz):
             d["columns"] = gb
         return d
 
-    def get_js_columns(self, d):
+    def get_js_columns(self, d: Dict[str, Any]) -> Dict[str, Any]:
         cols = self.form_data.get("js_columns") or []
         return {col: d.get(col) for col in cols}
 
@@ -2393,7 +2408,7 @@ class BaseDeckGLViz(BaseViz):
             "metricLabels": self.metric_labels,
         }
 
-    def get_properties(self, d):
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
         raise NotImplementedError()
 
 
@@ -2406,7 +2421,7 @@ class DeckScatterViz(BaseDeckGLViz):
     spatial_control_keys = ["spatial"]
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         fd = self.form_data
         self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
         self.point_radius_fixed = fd.get("point_radius_fixed") or {
@@ -2415,19 +2430,21 @@ class DeckScatterViz(BaseDeckGLViz):
         }
         return super().query_obj()
 
-    def get_metrics(self):
+    def get_metrics(self) -> List[str]:
         self.metric = None
         if self.point_radius_fixed.get("type") == "metric":
-            self.metric = self.point_radius_fixed.get("value")
+            self.metric = self.point_radius_fixed["value"]
             return [self.metric]
-        return None
+        return []
 
-    def get_properties(self, d):
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
         return {
-            "metric": d.get(self.metric_label),
+            "metric": d.get(self.metric_label) if self.metric_label else None,
             "radius": self.fixed_value
             if self.fixed_value
-            else d.get(self.metric_label),
+            else d.get(self.metric_label)
+            if self.metric_label
+            else None,
             "cat_color": d.get(self.dim) if self.dim else None,
             "position": d.get("spatial"),
             DTTM_ALIAS: d.get(DTTM_ALIAS),
@@ -2453,20 +2470,20 @@ class DeckScreengrid(BaseDeckGLViz):
     spatial_control_keys = ["spatial"]
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         fd = self.form_data
-        self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity")
+        self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
         return super().query_obj()
 
-    def get_properties(self, d):
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "position": d.get("spatial"),
-            "weight": d.get(self.metric_label) or 1,
+            "weight": (d.get(self.metric_label) if self.metric_label else None) or 1,
             "__timestamp": d.get(DTTM_ALIAS) or d.get("__time"),
         }
 
     def get_data(self, df: pd.DataFrame) -> VizData:
-        self.metric_label = utils.get_metric_name(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
         return super().get_data(df)
 
 
@@ -2478,15 +2495,18 @@ class DeckGrid(BaseDeckGLViz):
     verbose_name = _("Deck.gl - 3D Grid")
     spatial_control_keys = ["spatial"]
 
-    def get_properties(self, d):
-        return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
+        return {
+            "position": d.get("spatial"),
+            "weight": (d.get(self.metric_label) if self.metric_label else None) or 1,
+        }
 
     def get_data(self, df: pd.DataFrame) -> VizData:
-        self.metric_label = utils.get_metric_name(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
         return super().get_data(df)
 
 
-def geohash_to_json(geohash_code):
+def geohash_to_json(geohash_code: str) -> List[List[float]]:
     p = geohash.bbox(geohash_code)
     return [
         [p.get("w"), p.get("n")],
@@ -2511,9 +2531,9 @@ class DeckPathViz(BaseDeckGLViz):
         "geohash": geohash_to_json,
     }
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         fd = self.form_data
-        self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity")
+        self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
         d = super().query_obj()
         self.metric = fd.get("metric")
         line_col = fd.get("line_column")
@@ -2525,11 +2545,11 @@ class DeckPathViz(BaseDeckGLViz):
             d["columns"].append(line_col)
         return d
 
-    def get_properties(self, d):
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
         fd = self.form_data
-        line_type = fd.get("line_type")
+        line_type = fd["line_type"]
         deser = self.deser_map[line_type]
-        line_column = fd.get("line_column")
+        line_column = fd["line_column"]
         path = deser(d[line_column])
         if fd.get("reverse_long_lat"):
             path = [(o[1], o[0]) for o in path]
@@ -2540,7 +2560,7 @@ class DeckPathViz(BaseDeckGLViz):
         return d
 
     def get_data(self, df: pd.DataFrame) -> VizData:
-        self.metric_label = utils.get_metric_name(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
         return super().get_data(df)
 
 
@@ -2552,18 +2572,18 @@ class DeckPolygon(DeckPathViz):
     deck_viz_key = "polygon"
     verbose_name = _("Deck.gl - Polygon")
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         fd = self.form_data
         self.elevation = fd.get("point_radius_fixed") or {"type": "fix", "value": 500}
         return super().query_obj()
 
-    def get_metrics(self):
+    def get_metrics(self) -> List[str]:
         metrics = [self.form_data.get("metric")]
         if self.elevation.get("type") == "metric":
             metrics.append(self.elevation.get("value"))
         return [metric for metric in metrics if metric]
 
-    def get_properties(self, d):
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
         super().get_properties(d)
         fd = self.form_data
         elevation = fd["point_radius_fixed"]["value"]
@@ -2582,11 +2602,14 @@ class DeckHex(BaseDeckGLViz):
     verbose_name = _("Deck.gl - 3D HEX")
     spatial_control_keys = ["spatial"]
 
-    def get_properties(self, d):
-        return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
+        return {
+            "position": d.get("spatial"),
+            "weight": (d.get(self.metric_label) if self.metric_label else None) or 1,
+        }
 
     def get_data(self, df: pd.DataFrame) -> VizData:
-        self.metric_label = utils.get_metric_name(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
         return super(DeckHex, self).get_data(df)
 
 
@@ -2597,15 +2620,15 @@ class DeckGeoJson(BaseDeckGLViz):
     viz_type = "deck_geojson"
     verbose_name = _("Deck.gl - GeoJSON")
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
         d["columns"] += [self.form_data.get("geojson")]
         d["metrics"] = []
         d["groupby"] = []
         return d
 
-    def get_properties(self, d):
-        geojson = d.get(self.form_data.get("geojson"))
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
+        geojson = d[self.form_data["geojson"]]
         return json.loads(geojson)
 
 
@@ -2618,12 +2641,12 @@ class DeckArc(BaseDeckGLViz):
     spatial_control_keys = ["start_spatial", "end_spatial"]
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         fd = self.form_data
         self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
         return super().query_obj()
 
-    def get_properties(self, d):
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
         dim = self.form_data.get("dimension")
         return {
             "sourcePosition": d.get("start_spatial"),
@@ -2653,15 +2676,15 @@ class EventFlowViz(BaseViz):
     credits = 'from <a href="https://github.com/williaster/data-ui">@data-ui</a>'
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         query = super().query_obj()
         form_data = self.form_data
 
-        event_key = form_data.get("all_columns_x")
-        entity_key = form_data.get("entity")
+        event_key = form_data["all_columns_x"]
+        entity_key = form_data["entity"]
         meta_keys = [
             col
-            for col in form_data.get("all_columns")
+            for col in form_data["all_columns"]
             if col != event_key and col != entity_key
         ]
 
@@ -2773,14 +2796,16 @@ class PartitionViz(NVD3TimeSeriesViz):
     viz_type = "partition"
     verbose_name = _("Partition Diagram")
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         query_obj = super().query_obj()
         time_op = self.form_data.get("time_series_option", "not_time")
         # Return time series data if the user specifies so
         query_obj["is_timeseries"] = time_op != "not_time"
         return query_obj
 
-    def levels_for(self, time_op, groups, df):
+    def levels_for(
+        self, time_op: str, groups: List[str], df: pd.DataFrame
+    ) -> Dict[int, pd.Series]:
         """
         Compute the partition at each `level` from the dataframe.
         """
@@ -2794,7 +2819,9 @@ class PartitionViz(NVD3TimeSeriesViz):
             )
         return levels
 
-    def levels_for_diff(self, time_op, groups, df):
+    def levels_for_diff(
+        self, time_op: str, groups: List[str], df: pd.DataFrame
+    ) -> Dict[int, pd.DataFrame]:
         # Obtain a unique list of the time grains
         times = list(set(df[DTTM_ALIAS]))
         times.sort()
@@ -2828,7 +2855,9 @@ class PartitionViz(NVD3TimeSeriesViz):
             )
         return levels
 
-    def levels_for_time(self, groups, df):
+    def levels_for_time(
+        self, groups: List[str], df: pd.DataFrame
+    ) -> Dict[int, VizData]:
         procs = {}
         for i in range(0, len(groups) + 1):
             self.form_data["groupby"] = groups[:i]
@@ -2837,11 +2866,19 @@ class PartitionViz(NVD3TimeSeriesViz):
         self.form_data["groupby"] = groups
         return procs
 
-    def nest_values(self, levels, level=0, metric=None, dims=()):
+    def nest_values(
+        self,
+        levels: Dict[int, pd.DataFrame],
+        level: int = 0,
+        metric: Optional[str] = None,
+        dims: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
         """
         Nest values at each level on the back-end with
         access and setting, instead of summing from the bottom.
         """
+        if dims is None:
+            dims = []
         if not level:
             return [
                 {
@@ -2856,7 +2893,7 @@ class PartitionViz(NVD3TimeSeriesViz):
                 {
                     "name": i,
                     "val": levels[1][metric][i],
-                    "children": self.nest_values(levels, 2, metric, (i,)),
+                    "children": self.nest_values(levels, 2, metric, [i]),
                 }
                 for i in levels[1][metric].index
             ]
@@ -2866,12 +2903,20 @@ class PartitionViz(NVD3TimeSeriesViz):
             {
                 "name": i,
                 "val": levels[level][metric][dims][i],
-                "children": self.nest_values(levels, level + 1, metric, dims + (i,)),
+                "children": self.nest_values(levels, level + 1, metric, dims + [i]),
             }
             for i in levels[level][metric][dims].index
         ]
 
-    def nest_procs(self, procs, level=-1, dims=(), time=None):
+    def nest_procs(
+        self,
+        procs: Dict[int, pd.DataFrame],
+        level: int = -1,
+        dims: Optional[Tuple[str, ...]] = None,
+        time: Any = None,
+    ) -> List[Dict[str, Any]]:
+        if dims is None:
+            dims = ()
         if level == -1:
             return [
                 {"name": m, "children": self.nest_procs(procs, 0, (m,))}
diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py
index f51580d..1df2528 100644
--- a/superset/viz_sip38.py
+++ b/superset/viz_sip38.py
@@ -20,6 +20,7 @@
 These objects represent the backend of all the visualizations that
 Superset can render.
 """
+# mypy: ignore-errors
 import copy
 import hashlib
 import inspect
@@ -610,7 +611,7 @@ class TableViz(BaseViz):
             raise QueryObjectValidationError(
                 _("Pick a granularity in the Time section or " "uncheck 'Include Time'")
             )
-        return fd.get("include_time")
+        return bool(fd.get("include_time"))
 
     def query_obj(self):
         d = super().query_obj()
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index 748d50b..f8eb8ce 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -974,7 +974,7 @@ class BaseDeckGLVizTestCase(SupersetTestCase):
         test_viz_deckgl = viz.DeckScatterViz(datasource, form_data)
         test_viz_deckgl.point_radius_fixed = {}
         result = test_viz_deckgl.get_metrics()
-        assert result is None
+        assert result == []
 
     def test_get_js_columns(self):
         form_data = load_fixture("deck_path_form_data.json")