Posted to commits@superset.apache.org by vi...@apache.org on 2019/12/31 07:26:39 UTC
[incubator-superset] branch master updated: chore: refactor, add typing and fix uncovered errors (#8900)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 5b690f9 chore: refactor, add typing and fix uncovered errors (#8900)
5b690f9 is described below
commit 5b690f94116b5b77c0af566d24f298cc6db469b9
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Tue Dec 31 09:26:23 2019 +0200
chore: refactor, add typing and fix uncovered errors (#8900)
* Add type annotations and fix inconsistencies
* Address review comments
* Remove incorrect typing of jsonable obj
---
superset/app.py | 2 +-
superset/common/query_context.py | 31 +++++-------
superset/config.py | 2 +-
superset/connectors/base/models.py | 18 ++++---
superset/connectors/connector_registry.py | 6 +--
superset/connectors/sqla/models.py | 59 +++++++++++++---------
superset/dataframe.py | 2 +-
superset/db_engine_specs/base.py | 2 +-
...e_form_strip_leading_and_trailing_whitespace.py | 6 +--
.../versions/c617da68de7d_form_nullable.py | 6 +--
.../migrations/versions/d94d33dbe938_form_strip.py | 6 +--
superset/models/core.py | 2 +-
superset/models/helpers.py | 4 +-
superset/models/tags.py | 4 +-
superset/sql_parse.py | 2 +-
superset/stats_logger.py | 2 +-
superset/views/annotations.py | 2 +-
superset/views/base.py | 7 +--
superset/views/datasource.py | 38 ++++++++------
superset/viz.py | 37 ++++++++------
tests/sqla_models_tests.py | 4 +-
tests/viz_tests.py | 4 +-
22 files changed, 137 insertions(+), 109 deletions(-)
diff --git a/superset/app.py b/superset/app.py
index abc6636..a037e64 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -221,7 +221,7 @@ class SupersetAppInitializer:
if self.config["ENABLE_CHUNK_ENCODING"]:
- class ChunkedEncodingFix(object): # pylint: disable=too-few-public-methods
+ class ChunkedEncodingFix: # pylint: disable=too-few-public-methods
def __init__(self, app):
self.app = app
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index d9d6358..b234c14 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -17,7 +17,7 @@
import logging
import pickle as pkl
from datetime import datetime, timedelta
-from typing import Any, Dict, List, Optional
+from typing import Any, ClassVar, Dict, List, Optional
import numpy as np
import pandas as pd
@@ -41,8 +41,8 @@ class QueryContext:
to retrieve the data payload for a given viz.
"""
- cache_type: str = "df"
- enforce_numerical_metrics: bool = True
+ cache_type: ClassVar[str] = "df"
+ enforce_numerical_metrics: ClassVar[bool] = True
datasource: BaseDatasource
queries: List[QueryObject]
@@ -53,20 +53,16 @@ class QueryContext:
# a vanilla python type https://github.com/python/mypy/issues/5288
def __init__(
self,
- datasource: Dict,
- queries: List[Dict],
+ datasource: Dict[str, Any],
+ queries: List[Dict[str, Any]],
force: bool = False,
custom_cache_timeout: Optional[int] = None,
) -> None:
- self.datasource = ConnectorRegistry.get_datasource( # type: ignore
- datasource.get("type"), # type: ignore
- int(datasource.get("id")), # type: ignore
- db.session,
+ self.datasource = ConnectorRegistry.get_datasource(
+ str(datasource["type"]), int(datasource["id"]), db.session
)
- self.queries = list(map(lambda query_obj: QueryObject(**query_obj), queries))
-
+ self.queries = [QueryObject(**query_obj) for query_obj in queries]
self.force = force
-
self.custom_cache_timeout = custom_cache_timeout
def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
@@ -78,7 +74,7 @@ class QueryContext:
timestamp_format = None
if self.datasource.type == "table":
- dttm_col = self.datasource.get_col(query_object.granularity)
+ dttm_col = self.datasource.get_column(query_object.granularity)
if dttm_col:
timestamp_format = dttm_col.python_date_format
@@ -115,17 +111,18 @@ class QueryContext:
"df": df,
}
+ @staticmethod
def df_metrics_to_num( # pylint: disable=invalid-name,no-self-use
- self, df: pd.DataFrame, query_object: QueryObject
+ df: pd.DataFrame, query_object: QueryObject
) -> None:
"""Converting metrics to numeric when pandas.read_sql cannot"""
- metrics = [metric for metric in query_object.metrics]
for col, dtype in df.dtypes.items():
- if dtype.type == np.object_ and col in metrics:
+ if dtype.type == np.object_ and col in query_object.metrics:
df[col] = pd.to_numeric(df[col], errors="coerce")
+ @staticmethod
def get_data( # pylint: disable=invalid-name,no-self-use
- self, df: pd.DataFrame
+ df: pd.DataFrame
) -> List[Dict]:
return df.to_dict(orient="records")
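[Editor's sketch] In the hunk above, df_metrics_to_num becomes a plain @staticmethod and coerces object-dtype metric columns with pd.to_numeric(..., errors="coerce"), so values an engine returned as strings become numbers and unparseable ones become NaN rather than raising. A minimal, self-contained sketch of that coercion pattern (plain pandas; the "sum__num" metric name is a made-up example, not Superset's):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"sum__num": ["1", "2", "bad"], "name": ["a", "b", "c"]})
    metrics = ["sum__num"]

    for col, dtype in df.dtypes.items():
        if dtype.type == np.object_ and col in metrics:
            # Unparseable values become NaN instead of raising
            df[col] = pd.to_numeric(df[col], errors="coerce")

    print(df["sum__num"].tolist())  # [1.0, 2.0, nan]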
diff --git a/superset/config.py b/superset/config.py
index 94ad245..1e0fa27 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -449,7 +449,7 @@ WARNING_MSG = None
# http://docs.celeryproject.org/en/latest/getting-started/brokers/index.html
-class CeleryConfig(object): # pylint: disable=too-few-public-methods
+class CeleryConfig: # pylint: disable=too-few-public-methods
BROKER_URL = "sqla+sqlite:///celerydb.sqlite"
CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks")
CELERY_RESULT_BACKEND = "db+sqlite:///celery_results.sqlite"
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index c4f7ba1..2f06245 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import json
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, Hashable, List, Optional, Type
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
@@ -44,7 +44,7 @@ class BaseDatasource(
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
column_class: Optional[Type] = None # link to derivative of BaseColumn
metric_class: Optional[Type] = None # link to derivative of BaseMetric
- owner_class = None
+ owner_class: Optional[User] = None
# Used to do code highlighting when displaying the query in the UI
query_language: Optional[str] = None
@@ -342,11 +342,14 @@ class BaseDatasource(
obj.get("columns"), self.columns, self.column_class, "column_name"
)
- def get_extra_cache_keys( # pylint: disable=unused-argument,no-self-use
- self, query_obj: Dict
- ) -> List[Any]:
+ def get_extra_cache_keys( # pylint: disable=no-self-use
+ self, query_obj: Dict[str, Any] # pylint: disable=unused-argument
+ ) -> List[Hashable]:
""" If a datasource needs to provide additional keys for calculation of
cache keys, those can be provided via this method
+
+ :param query_obj: The dict representation of a query object
+ :return: list of keys
"""
return []
@@ -404,6 +407,10 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
raise NotImplementedError()
@property
+ def python_date_format(self):
+ raise NotImplementedError()
+
+ @property
def data(self) -> Dict[str, Any]:
attrs = (
"id",
@@ -415,7 +422,6 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
"groupby",
"is_dttm",
"type",
- "python_date_format",
)
return {s: getattr(self, s) for s in attrs if hasattr(self, s)}
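[Editor's sketch] get_extra_cache_keys now returns List[Hashable] rather than List[Any], since every extra key ultimately has to be folded into a hashable, serializable cache key. An illustrative sketch of how such keys might enter a digest (this is not Superset's actual key scheme):

    import hashlib
    from typing import Hashable, List

    def cache_key(base: str, extra_cache_keys: List[Hashable]) -> str:
        # Fold each extra key into the digest alongside the base key
        md = hashlib.md5(base.encode())
        for key in extra_cache_keys:
            md.update(str(key).encode())
        return md.hexdigest()

    print(cache_key("viz:table:123", ["current_user", 42]))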
diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py
index 2cf79e7..4097066 100644
--- a/superset/connectors/connector_registry.py
+++ b/superset/connectors/connector_registry.py
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
-class ConnectorRegistry(object):
+class ConnectorRegistry:
""" Central Registry for all available datasource engines"""
sources: Dict[str, Type["BaseDatasource"]] = {}
@@ -43,11 +43,11 @@ class ConnectorRegistry(object):
@classmethod
def get_datasource(
cls, datasource_type: str, datasource_id: int, session: Session
- ) -> Optional["BaseDatasource"]:
+ ) -> "BaseDatasource":
return (
session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
- .first()
+ .one()
)
@classmethod
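[Editor's sketch] The switch from .first() to .one() above tightens the return type: .first() yields Optional and pushes a None-check onto every caller, while .one() raises NoResultFound when no row matches, which the new try/except in views/datasource.py relies on. A runnable sketch of the difference (generic SQLAlchemy 1.3-style; the Widget model and in-memory database are hypothetical):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.orm.exc import NoResultFound

    Base = declarative_base()

    class Widget(Base):
        __tablename__ = "widgets"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    print(session.query(Widget).filter_by(id=1).first())  # None; caller must check
    try:
        session.query(Widget).filter_by(id=1).one()       # raises instead
    except NoResultFound:
        print("no such widget")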
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 22d5c85..4930a30 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -19,7 +19,7 @@ import logging
import re
from collections import OrderedDict
from datetime import datetime
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union
import pandas as pd
import sqlalchemy as sa
@@ -84,7 +84,7 @@ class AnnotationDatasource(BaseDatasource):
cache_timeout = 0
- def query(self, query_obj: Dict) -> QueryResult:
+ def query(self, query_obj: Dict[str, Any]) -> QueryResult:
df = None
error_message = None
qry = db.session.query(Annotation)
@@ -537,16 +537,9 @@ class SqlaTable(Model, BaseDatasource):
latest_partition=False,
)
- def get_col(self, col_name: str) -> Optional[Column]:
- columns = self.columns
- for col in columns:
- if col_name == col.column_name:
- return col
- return None
-
@property
def data(self) -> Dict:
- d = super(SqlaTable, self).data
+ d = super().data
if self.type == "table":
grains = self.database.grains() or []
if grains:
@@ -598,7 +591,7 @@ class SqlaTable(Model, BaseDatasource):
def get_template_processor(self, **kwargs):
return get_template_processor(table=self, database=self.database, **kwargs)
- def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended:
+ def get_query_str_extended(self, query_obj: Dict[str, Any]) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
logging.info(sql)
@@ -608,7 +601,7 @@ class SqlaTable(Model, BaseDatasource):
labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
)
- def get_query_str(self, query_obj: Dict) -> str:
+ def get_query_str(self, query_obj: Dict[str, Any]) -> str:
query_str_ext = self.get_query_str_extended(query_obj)
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
return ";\n\n".join(all_queries) + ";"
@@ -976,14 +969,23 @@ class SqlaTable(Model, BaseDatasource):
return or_(*groups)
- def query(self, query_obj: Dict) -> QueryResult:
+ def query(self, query_obj: Dict[str, Any]) -> QueryResult:
qry_start_dttm = datetime.now()
query_str_ext = self.get_query_str_extended(query_obj)
sql = query_str_ext.sql
status = utils.QueryStatus.SUCCESS
error_message = None
- def mutator(df):
+ def mutator(df: pd.DataFrame) -> None:
+ """
+ Some engines change the case or generate bespoke column names, either by
+ default or due to lack of support for aliasing. This function ensures that
+ the column names in the DataFrame correspond to what is expected by
+ the viz components.
+
+ :param df: Original DataFrame returned by the engine
+ """
+
labels_expected = query_str_ext.labels_expected
if df is not None and not df.empty:
if len(df.columns) != len(labels_expected):
@@ -993,7 +995,6 @@ class SqlaTable(Model, BaseDatasource):
)
else:
df.columns = labels_expected
- return df
try:
df = self.database.get_df(sql, self.schema, mutator)
@@ -1135,13 +1136,16 @@ class SqlaTable(Model, BaseDatasource):
def default_query(qry) -> Query:
return qry.filter_by(is_sqllab_view=False)
- def has_extra_cache_keys(self, query_obj: Dict) -> bool:
+ def has_calls_to_cache_key_wrapper(self, query_obj: Dict[str, Any]) -> bool:
"""
- Detects the presence of calls to cache_key_wrapper in items in query_obj that can
- be templated.
+ Detects the presence of calls to `cache_key_wrapper` in items in query_obj that
+ can be templated. If any are present, the query must be evaluated to extract
+ additional keys for the cache key. This method is needed to avoid executing
+ the template code unnecessarily, as it may contain expensive calls, e.g. to
+ extract the latest partition of a database.
:param query_obj: query object to analyze
- :return: True if at least one item calls cache_key_wrapper, otherwise False
+ :return: True if at least one item calls `cache_key_wrapper`, otherwise False
"""
regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}")
templatable_statements: List[str] = []
@@ -1159,12 +1163,19 @@ class SqlaTable(Model, BaseDatasource):
return True
return False
- def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]:
- if self.has_extra_cache_keys(query_obj):
+ def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]:
+ """
+ The cache key of a SqlaTable needs to consider any keys added by the parent class
+ and any keys added via `cache_key_wrapper`.
+
+ :param query_obj: query object to analyze
+ :return: True if at least one item calls `cache_key_wrapper`, otherwise False
+ """
+ extra_cache_keys = super().get_extra_cache_keys(query_obj)
+ if self.has_calls_to_cache_key_wrapper(query_obj):
sqla_query = self.get_sqla_query(**query_obj)
- extra_cache_keys = sqla_query.extra_cache_keys
- return extra_cache_keys
- return []
+ extra_cache_keys += sqla_query.extra_cache_keys
+ return extra_cache_keys
sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
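[Editor's sketch] The renamed has_calls_to_cache_key_wrapper acts as a cheap gate: the potentially expensive template evaluation only runs when a Jinja block actually invokes cache_key_wrapper. A minimal sketch of the regex gate from the hunk above (the SQL strings are made-up examples):

    import re

    regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}")

    templated = "SELECT * FROM t WHERE user = '{{ cache_key_wrapper(current_username()) }}'"
    plain = "SELECT * FROM t WHERE user = 'alice'"

    print(bool(regex.search(templated)))  # True  -> evaluate template for extra keys
    print(bool(regex.search(plain)))      # False -> skip the expensive evaluation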
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 9c36efc..1163a05 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -68,7 +68,7 @@ def is_numeric(dtype):
return np.issubdtype(dtype, np.number)
-class SupersetDataFrame(object):
+class SupersetDataFrame:
# Mapping numpy dtype.char to generic database types
type_map = {
"b": "BOOL", # boolean
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 0f33758..90dd89b 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -104,7 +104,7 @@ def compile_timegrain_expression(
return element.name.replace("{col}", compiler.process(element.col, **kw))
-class LimitMethod(object): # pylint: disable=too-few-public-methods
+class LimitMethod: # pylint: disable=too-few-public-methods
"""Enum the ways that limits can be applied"""
FETCH_MANY = "fetch_many"
diff --git a/superset/migrations/versions/258b5280a45e_form_strip_leading_and_trailing_whitespace.py b/superset/migrations/versions/258b5280a45e_form_strip_leading_and_trailing_whitespace.py
index ddf7295..0aa924d 100644
--- a/superset/migrations/versions/258b5280a45e_form_strip_leading_and_trailing_whitespace.py
+++ b/superset/migrations/versions/258b5280a45e_form_strip_leading_and_trailing_whitespace.py
@@ -33,7 +33,7 @@ from superset.utils.core import MediumText
Base = declarative_base()
-class BaseColumnMixin(object):
+class BaseColumnMixin:
id = Column(Integer, primary_key=True)
column_name = Column(String(255))
description = Column(Text)
@@ -41,12 +41,12 @@ class BaseColumnMixin(object):
verbose_name = Column(String(1024))
-class BaseDatasourceMixin(object):
+class BaseDatasourceMixin:
id = Column(Integer, primary_key=True)
description = Column(Text)
-class BaseMetricMixin(object):
+class BaseMetricMixin:
id = Column(Integer, primary_key=True)
d3format = Column(String(128))
description = Column(Text)
diff --git a/superset/migrations/versions/c617da68de7d_form_nullable.py b/superset/migrations/versions/c617da68de7d_form_nullable.py
index d66ce29..c5ffd72 100644
--- a/superset/migrations/versions/c617da68de7d_form_nullable.py
+++ b/superset/migrations/versions/c617da68de7d_form_nullable.py
@@ -36,7 +36,7 @@ from superset.utils.core import MediumText
Base = declarative_base()
-class BaseColumnMixin(object):
+class BaseColumnMixin:
id = Column(Integer, primary_key=True)
column_name = Column(String(255))
description = Column(Text)
@@ -44,12 +44,12 @@ class BaseColumnMixin(object):
verbose_name = Column(String(1024))
-class BaseDatasourceMixin(object):
+class BaseDatasourceMixin:
id = Column(Integer, primary_key=True)
description = Column(Text)
-class BaseMetricMixin(object):
+class BaseMetricMixin:
id = Column(Integer, primary_key=True)
d3format = Column(String(128))
description = Column(Text)
diff --git a/superset/migrations/versions/d94d33dbe938_form_strip.py b/superset/migrations/versions/d94d33dbe938_form_strip.py
index c1882f6..a899fd5 100644
--- a/superset/migrations/versions/d94d33dbe938_form_strip.py
+++ b/superset/migrations/versions/d94d33dbe938_form_strip.py
@@ -36,7 +36,7 @@ from superset.utils.core import MediumText
Base = declarative_base()
-class BaseColumnMixin(object):
+class BaseColumnMixin:
id = Column(Integer, primary_key=True)
column_name = Column(String(255))
description = Column(Text)
@@ -44,12 +44,12 @@ class BaseColumnMixin(object):
verbose_name = Column(String(1024))
-class BaseDatasourceMixin(object):
+class BaseDatasourceMixin:
id = Column(Integer, primary_key=True)
description = Column(Text)
-class BaseMetricMixin(object):
+class BaseMetricMixin:
id = Column(Integer, primary_key=True)
d3format = Column(String(128))
description = Column(Text)
diff --git a/superset/models/core.py b/superset/models/core.py
index a262b4d..5730add 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -361,7 +361,7 @@ class Database(
)
if mutator:
- df = mutator(df)
+ mutator(df)
for k, v in df.dtypes.items():
if v.type == numpy.object_ and needs_conversion(df[k]):
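[Editor's sketch] The mutator contract changed in the hunk above: the function now mutates the DataFrame in place and returns None, so the call site drops the reassignment. A tiny sketch of an in-place mutator under that contract (plain pandas; labels_expected is a made-up example):

    import pandas as pd

    labels_expected = ["ds", "sum__num"]

    def mutator(df: pd.DataFrame) -> None:
        # Rename engine-generated column names in place; no return value.
        if len(df.columns) == len(labels_expected):
            df.columns = labels_expected

    df = pd.DataFrame([["2019-12-31", 1]], columns=["DS", "SUM__NUM"])
    mutator(df)              # note: no `df = mutator(df)` reassignment
    print(list(df.columns))  # ['ds', 'sum__num']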
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 97216b9..b11fb38 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -47,7 +47,7 @@ def json_to_dict(json_str):
return {}
-class ImportMixin(object):
+class ImportMixin:
export_parent: Optional[str] = None
# The name of the attribute
# with the SQL Alchemy back reference
@@ -361,7 +361,7 @@ class AuditMixinNullable(AuditMixin):
return Markup(f'<span class="no-wrap">{self.changed_on_humanized}</span>')
-class QueryResult(object): # pylint: disable=too-few-public-methods
+class QueryResult: # pylint: disable=too-few-public-methods
"""Object returned by the query interface"""
diff --git a/superset/models/tags.py b/superset/models/tags.py
index 6c29430..779113b 100644
--- a/superset/models/tags.py
+++ b/superset/models/tags.py
@@ -103,7 +103,7 @@ def get_object_type(class_name):
raise Exception("No mapping found for {0}".format(class_name))
-class ObjectUpdater(object):
+class ObjectUpdater:
object_type: Optional[str] = None
@@ -204,7 +204,7 @@ class QueryUpdater(ObjectUpdater):
return [target.user_id]
-class FavStarUpdater(object):
+class FavStarUpdater:
@classmethod
def after_insert(cls, mapper, connection, target):
# pylint: disable=unused-argument
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index b836726..3a1baab 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -49,7 +49,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None
-class ParsedQuery(object):
+class ParsedQuery:
def __init__(self, sql_statement):
self.sql: str = sql_statement
self._table_names: Set[str] = set()
diff --git a/superset/stats_logger.py b/superset/stats_logger.py
index e2d483f..aca511a 100644
--- a/superset/stats_logger.py
+++ b/superset/stats_logger.py
@@ -19,7 +19,7 @@ import logging
from colorama import Fore, Style
-class BaseStatsLogger(object):
+class BaseStatsLogger:
"""Base class for logging realtime events"""
def __init__(self, prefix="superset"):
diff --git a/superset/views/annotations.py b/superset/views/annotations.py
index d07344c..757e412 100644
--- a/superset/views/annotations.py
+++ b/superset/views/annotations.py
@@ -24,7 +24,7 @@ from superset.models.annotations import Annotation, AnnotationLayer
from .base import DeleteMixin, SupersetModelView
-class StartEndDttmValidator(object): # pylint: disable=too-few-public-methods
+class StartEndDttmValidator: # pylint: disable=too-few-public-methods
"""
Validates dttm fields.
"""
diff --git a/superset/views/base.py b/superset/views/base.py
index 15f711e..3a5dd64 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -167,7 +167,8 @@ def get_user_roles():
class BaseSupersetView(BaseView):
- def json_response(self, obj, status=200): # pylint: disable=no-self-use
+ @staticmethod
+ def json_response(obj, status=200) -> Response: # pylint: disable=no-self-use
return Response(
json.dumps(obj, default=utils.json_int_dttm_ser, ignore_nan=True),
status=status,
@@ -263,7 +264,7 @@ def validate_json(_form, field):
raise Exception(_("json isn't valid"))
-class YamlExportMixin(object): # pylint: disable=too-few-public-methods
+class YamlExportMixin: # pylint: disable=too-few-public-methods
"""
Override this if you want a dict response instead, with a certain key.
Used on DatabaseView for cli compatibility
@@ -286,7 +287,7 @@ class YamlExportMixin(object): # pylint: disable=too-few-public-methods
)
-class DeleteMixin(object): # pylint: disable=too-few-public-methods
+class DeleteMixin: # pylint: disable=too-few-public-methods
def _delete(self, primary_key):
"""
Delete function logic, override to implement diferent logic
diff --git a/superset/views/datasource.py b/superset/views/datasource.py
index 5eb7fa3..1f28238 100644
--- a/superset/views/datasource.py
+++ b/superset/views/datasource.py
@@ -17,9 +17,10 @@
import json
from collections import Counter
-from flask import request
+from flask import request, Response
from flask_appbuilder import expose
from flask_appbuilder.security.decorators import has_access_api
+from sqlalchemy.orm.exc import NoResultFound
from superset import appbuilder, db
from superset.connectors.connector_registry import ConnectorRegistry
@@ -35,15 +36,19 @@ class Datasource(BaseSupersetView):
@has_access_api
@api
@handle_api_exception
- def save(self):
- datasource = json.loads(request.form.get("data"))
+ def save(self) -> Response:
+ data = request.form.get("data")
+ if not isinstance(data, str):
+ return json_error_response("Request missing data field.", status="500")
+
+ datasource = json.loads(data)
datasource_id = datasource.get("id")
datasource_type = datasource.get("type")
orm_datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
- if "owners" in datasource:
+ if "owners" in datasource and orm_datasource.owner_class is not None:
datasource["owners"] = (
db.session.query(orm_datasource.owner_class)
.filter(orm_datasource.owner_class.id.in_(datasource["owners"]))
@@ -72,23 +77,24 @@ class Datasource(BaseSupersetView):
@has_access_api
@api
@handle_api_exception
- def get(self, datasource_type, datasource_id):
- orm_datasource = ConnectorRegistry.get_datasource(
- datasource_type, datasource_id, db.session
- )
-
- if not orm_datasource:
+ def get(self, datasource_type: str, datasource_id: int) -> Response:
+ try:
+ orm_datasource = ConnectorRegistry.get_datasource(
+ datasource_type, datasource_id, db.session
+ )
+ if not orm_datasource.data:
+ return json_error_response(
+ "Error fetching datasource data.", status="500"
+ )
+ return self.json_response(orm_datasource.data)
+ except NoResultFound:
return json_error_response("This datasource does not exist", status="400")
- elif not orm_datasource.data:
- return json_error_response("Error fetching datasource data.", status="500")
-
- return self.json_response(orm_datasource.data)
@expose("/external_metadata/<datasource_type>/<datasource_id>/")
@has_access_api
@api
@handle_api_exception
- def external_metadata(self, datasource_type=None, datasource_id=None):
+ def external_metadata(self, datasource_type: str, datasource_id: int) -> Response:
"""Gets column info from the source system"""
if datasource_type == "druid":
datasource = ConnectorRegistry.get_datasource(
@@ -104,6 +110,8 @@ class Datasource(BaseSupersetView):
table_name=request.args.get("table_name"),
schema=request.args.get("schema") or None,
)
+ else:
+ raise Exception(f"Unsupported datasource_type: {datasource_type}")
external_metadata = datasource.external_metadata()
return self.json_response(external_metadata)
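[Editor's sketch] In the save() change above, request.form.get("data") is typed Optional[str], so the isinstance check narrows the type before json.loads, satisfying mypy and returning a clean error response instead of a TypeError. A sketch of the narrowing pattern outside Flask (the parse_payload helper is hypothetical; `form` stands in for request.form):

    import json
    from typing import Any, Dict, Optional

    def parse_payload(form: Dict[str, str], key: str = "data") -> Dict[str, Any]:
        raw: Optional[str] = form.get(key)
        if not isinstance(raw, str):
            raise ValueError("Request missing data field.")
        # mypy now knows `raw` is str from here on
        return json.loads(raw)

    print(parse_payload({"data": '{"id": 1, "type": "table"}'}))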
diff --git a/superset/viz.py b/superset/viz.py
index 0efb0d3..beff59d 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -32,7 +32,7 @@ from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta
from functools import reduce
from itertools import product
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
import geohash
import numpy as np
@@ -49,6 +49,7 @@ from pandas.tseries.frequencies import to_offset
from superset import app, cache, get_css_manifest_files
from superset.constants import NULL_STRING
from superset.exceptions import NullValueException, SpatialException
+from superset.models.helpers import QueryResult
from superset.utils import core as utils
from superset.utils.core import (
DTTM_ALIAS,
@@ -57,6 +58,9 @@ from superset.utils.core import (
to_adhoc,
)
+if TYPE_CHECKING:
+ from superset.connectors.base.models import BaseDatasource
+
config = app.config
stats_logger = config["STATS_LOGGER"]
relative_start = config["DEFAULT_RELATIVE_START_TIME"]
@@ -74,7 +78,7 @@ METRIC_KEYS = [
]
-class BaseViz(object):
+class BaseViz:
"""All visualizations derive this base class"""
@@ -85,7 +89,12 @@ class BaseViz(object):
cache_type = "df"
enforce_numerical_metrics = True
- def __init__(self, datasource, form_data, force=False):
+ def __init__(
+ self,
+ datasource: "BaseDatasource",
+ form_data: Dict[str, Any],
+ force: bool = False,
+ ):
if not datasource:
raise Exception(_("Viz is missing a datasource"))
@@ -102,7 +111,7 @@ class BaseViz(object):
self.status = None
self.error_msg = ""
- self.results = None
+ self.results: Optional[QueryResult] = None
self.error_message = None
self.force = force
@@ -110,10 +119,9 @@ class BaseViz(object):
# this is useful to trigger the <CachedLabel /> when
# in the cases where visualization have many queries
# (FilterBox for instance)
- self._some_from_cache = False
- self._any_cache_key = None
- self._any_cached_dttm = None
- self._extra_chart_data = []
+ self._any_cache_key: Optional[str] = None
+ self._any_cached_dttm: Optional[str] = None
+ self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = []
self.process_metrics()
@@ -195,9 +203,9 @@ class BaseViz(object):
timestamp_format = None
if self.datasource.type == "table":
- dttm_col = self.datasource.get_col(query_obj["granularity"])
- if dttm_col:
- timestamp_format = dttm_col.python_date_format
+ granularity_col = self.datasource.get_column(query_obj["granularity"])
+ if granularity_col:
+ timestamp_format = granularity_col.python_date_format
# The datasource here can be different backend but the interface is common
self.results = self.datasource.query(query_obj)
@@ -258,17 +266,14 @@ class BaseViz(object):
merge_extra_filters(self.form_data)
utils.split_adhoc_filters_into_base_filters(self.form_data)
- def query_obj(self):
+ def query_obj(self) -> Dict[str, Any]:
"""Building a query object"""
form_data = self.form_data
self.process_query_filters()
gb = form_data.get("groupby") or []
metrics = self.all_metrics or []
columns = form_data.get("columns") or []
- groupby = []
- for o in gb + columns:
- if o not in groupby:
- groupby.append(o)
+ groupby = list(set(gb + columns))
is_timeseries = self.is_timeseries
if DTTM_ALIAS in groupby:
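[Editor's sketch] The TYPE_CHECKING guard added near the top of viz.py lets the "BaseDatasource" annotation exist without importing the class at runtime, avoiding a circular import; the annotation is written as a string so only the type checker resolves it. A minimal sketch of the pattern (module names here are illustrative, not Superset's real import graph):

    from typing import TYPE_CHECKING, Any, Dict

    if TYPE_CHECKING:
        # Imported only during type checking; no runtime import cycle.
        from mypackage.models import BaseDatasource  # hypothetical module

    class BaseViz:
        def __init__(self, datasource: "BaseDatasource", form_data: Dict[str, Any]) -> None:
            self.datasource = datasource
            self.form_data = form_data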
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index 0124bd0..b927ce8 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -60,7 +60,7 @@ class DatabaseModelTestCase(SupersetTestCase):
"extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
}
extra_cache_keys = table.get_extra_cache_keys(query_obj)
- self.assertTrue(table.has_extra_cache_keys(query_obj))
+ self.assertTrue(table.has_calls_to_cache_key_wrapper(query_obj))
self.assertListEqual(extra_cache_keys, ["user_1", "user_2"])
def test_has_no_extra_cache_keys(self):
@@ -81,5 +81,5 @@ class DatabaseModelTestCase(SupersetTestCase):
"extras": {"where": "(user != 'abc')"},
}
extra_cache_keys = table.get_extra_cache_keys(query_obj)
- self.assertFalse(table.has_extra_cache_keys(query_obj))
+ self.assertFalse(table.has_calls_to_cache_key_wrapper(query_obj))
self.assertListEqual(extra_cache_keys, [])
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index ad94e34..b9edb56 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -100,7 +100,7 @@ class BaseVizTestCase(SupersetTestCase):
datasource.type = "table"
datasource.query = Mock(return_value=results)
mock_dttm_col = Mock()
- datasource.get_col = Mock(return_value=mock_dttm_col)
+ datasource.get_column = Mock(return_value=mock_dttm_col)
test_viz = viz.BaseViz(datasource, form_data)
test_viz.df_metrics_to_num = Mock()
@@ -109,7 +109,7 @@ class BaseVizTestCase(SupersetTestCase):
results.df = pd.DataFrame(data={DTTM_ALIAS: ["1960-01-01 05:00:00"]})
datasource.offset = 0
mock_dttm_col = Mock()
- datasource.get_col = Mock(return_value=mock_dttm_col)
+ datasource.get_column = Mock(return_value=mock_dttm_col)
mock_dttm_col.python_date_format = "epoch_ms"
result = test_viz.get_df(query_obj)