Posted to commits@superset.apache.org by hu...@apache.org on 2023/01/20 15:26:16 UTC
[superset] 01/01: Re-patch SqlaTable into ExploreMixin
This is an automated email from the ASF dual-hosted git repository.
hugh pushed a commit to branch fix-explore-mixin
in repository https://gitbox.apache.org/repos/asf/superset.git
commit e386bc426c3080a48d9be1b5fafd8e6fbd84df63
Author: hughhhh <hu...@gmail.com>
AuthorDate: Fri Jan 20 17:25:56 2023 +0200
Re-patch SqlaTable into ExploreMixin
---
superset/connectors/sqla/models.py | 1308 ++++++++++++-----------
superset/models/helpers.py | 349 +++---
superset/models/sql_lab.py | 3 +-
superset/result_set.py | 4 +-
superset/utils/pandas_postprocessing/boxplot.py | 8 +-
superset/utils/pandas_postprocessing/flatten.py | 2 +-
6 files changed, 854 insertions(+), 820 deletions(-)
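In short, this commit begins consolidating query generation: SqlaTable now also inherits from ExploreMixin, its own get_sqla_query implementation in superset/connectors/sqla/models.py is commented out, and the mixin's version in superset/models/helpers.py is patched to match the SqlaTable behavior. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of a mixin carrying shared query building; the names (build_query, QueryResult as a bare dataclass, birth_names) are illustrative stand-ins, not Superset's actual API:

    from dataclasses import dataclass

    @dataclass
    class QueryResult:
        """Stand-in for Superset's richer QueryResult."""
        sql: str

    class ExploreMixin:
        """Shared query building, usable by any datasource model."""

        @property
        def table_name(self) -> str:
            raise NotImplementedError()  # concrete models override this

        def build_query(self, limit: int = 100) -> QueryResult:
            # One shared implementation replaces per-model copies; the real
            # get_sqla_query takes dozens of parameters instead of one.
            return QueryResult(sql=f"SELECT * FROM {self.table_name} LIMIT {limit}")

    class SqlaTable(ExploreMixin):
        """Stand-in for the ORM model; it inherits build_query from the mixin."""

        def __init__(self, name: str) -> None:
            self._name = name

        @property
        def table_name(self) -> str:
            return self._name

    print(SqlaTable("birth_names").build_query().sql)
    # -> SELECT * FROM birth_names LIMIT 100

The old SqlaTable code is left commented out rather than deleted, presumably so behavior can be compared while the mixin is brought up to parity.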
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c5fd025f4e..b363188b87 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -105,7 +105,12 @@ from superset.jinja_context import (
)
from superset.models.annotations import Annotation
from superset.models.core import Database
-from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult
+from superset.models.helpers import (
+ AuditMixinNullable,
+ CertificationMixin,
+ ExploreMixin,
+ QueryResult,
+)
from superset.sql_parse import ParsedQuery, sanitize_clause
from superset.superset_typing import (
AdhocColumn,
@@ -149,12 +154,13 @@ class SqlaQuery(NamedTuple):
prequeries: List[str]
sqla_query: Select
+from superset.models.helpers import QueryStringExtended
-class QueryStringExtended(NamedTuple):
- applied_template_filters: Optional[List[str]]
- labels_expected: List[str]
- prequeries: List[str]
- sql: str
+# class QueryStringExtended(NamedTuple):
+# applied_template_filters: Optional[List[str]]
+# labels_expected: List[str]
+# prequeries: List[str]
+# sql: str
@dataclass
@@ -534,7 +540,7 @@ def _process_sql_expression(
return expression
-class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods
+class SqlaTable(Model, BaseDatasource, ExploreMixin): # pylint: disable=too-many-public-methods
"""An ORM object for SqlAlchemy table references"""
type = "table"
@@ -980,7 +986,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
return self.make_sqla_column_compatible(sqla_metric, label)
- def adhoc_column_to_sqla(
+ def adhoc_column_to_sqla( # type: ignore
self,
col: AdhocColumn,
template_processor: Optional[BaseTemplateProcessor] = None,
@@ -1118,649 +1124,649 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def text(self, clause: str) -> TextClause:
return self.db_engine_spec.get_text_clause(clause)
- def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
- self,
- apply_fetch_values_predicate: bool = False,
- columns: Optional[List[ColumnTyping]] = None,
- extras: Optional[Dict[str, Any]] = None,
- filter: Optional[ # pylint: disable=redefined-builtin
- List[QueryObjectFilterClause]
- ] = None,
- from_dttm: Optional[datetime] = None,
- granularity: Optional[str] = None,
- groupby: Optional[List[Column]] = None,
- inner_from_dttm: Optional[datetime] = None,
- inner_to_dttm: Optional[datetime] = None,
- is_rowcount: bool = False,
- is_timeseries: bool = True,
- metrics: Optional[List[Metric]] = None,
- orderby: Optional[List[OrderBy]] = None,
- order_desc: bool = True,
- to_dttm: Optional[datetime] = None,
- series_columns: Optional[List[Column]] = None,
- series_limit: Optional[int] = None,
- series_limit_metric: Optional[Metric] = None,
- row_limit: Optional[int] = None,
- row_offset: Optional[int] = None,
- timeseries_limit: Optional[int] = None,
- timeseries_limit_metric: Optional[Metric] = None,
- time_shift: Optional[str] = None,
- ) -> SqlaQuery:
- """Querying any sqla table from this common interface"""
- if granularity not in self.dttm_cols and granularity is not None:
- granularity = self.main_dttm_col
-
- extras = extras or {}
- time_grain = extras.get("time_grain_sqla")
-
- template_kwargs = {
- "columns": columns,
- "from_dttm": from_dttm.isoformat() if from_dttm else None,
- "groupby": groupby,
- "metrics": metrics,
- "row_limit": row_limit,
- "row_offset": row_offset,
- "time_column": granularity,
- "time_grain": time_grain,
- "to_dttm": to_dttm.isoformat() if to_dttm else None,
- "table_columns": [col.column_name for col in self.columns],
- "filter": filter,
- }
- columns = columns or []
- groupby = groupby or []
- series_column_names = utils.get_column_names(series_columns or [])
- # deprecated, to be removed in 2.0
- if is_timeseries and timeseries_limit:
- series_limit = timeseries_limit
- series_limit_metric = series_limit_metric or timeseries_limit_metric
- template_kwargs.update(self.template_params_dict)
- extra_cache_keys: List[Any] = []
- template_kwargs["extra_cache_keys"] = extra_cache_keys
- removed_filters: List[str] = []
- applied_template_filters: List[str] = []
- template_kwargs["removed_filters"] = removed_filters
- template_kwargs["applied_filters"] = applied_template_filters
- template_processor = self.get_template_processor(**template_kwargs)
- db_engine_spec = self.db_engine_spec
- prequeries: List[str] = []
- orderby = orderby or []
- need_groupby = bool(metrics is not None or groupby)
- metrics = metrics or []
-
- # For backward compatibility
- if granularity not in self.dttm_cols and granularity is not None:
- granularity = self.main_dttm_col
-
- columns_by_name: Dict[str, TableColumn] = {
- col.column_name: col for col in self.columns
- }
-
- metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
-
- if not granularity and is_timeseries:
- raise QueryObjectValidationError(
- _(
- "Datetime column not provided as part table configuration "
- "and is required by this type of chart"
- )
- )
- if not metrics and not columns and not groupby:
- raise QueryObjectValidationError(_("Empty query?"))
-
- metrics_exprs: List[ColumnElement] = []
- for metric in metrics:
- if utils.is_adhoc_metric(metric):
- assert isinstance(metric, dict)
- metrics_exprs.append(
- self.adhoc_metric_to_sqla(
- metric=metric,
- columns_by_name=columns_by_name,
- template_processor=template_processor,
- )
- )
- elif isinstance(metric, str) and metric in metrics_by_name:
- metrics_exprs.append(
- metrics_by_name[metric].get_sqla_col(
- template_processor=template_processor
- )
- )
- else:
- raise QueryObjectValidationError(
- _("Metric '%(metric)s' does not exist", metric=metric)
- )
-
- if metrics_exprs:
- main_metric_expr = metrics_exprs[0]
- else:
- main_metric_expr, label = literal_column("COUNT(*)"), "ccount"
- main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
-
- # To ensure correct handling of the ORDER BY labeling we need to reference the
- # metric instance if defined in the SELECT clause.
- # use the key of the ColumnClause for the expected label
- metrics_exprs_by_label = {m.key: m for m in metrics_exprs}
- metrics_exprs_by_expr = {str(m): m for m in metrics_exprs}
-
- # Since orderby may use adhoc metrics, too; we need to process them first
- orderby_exprs: List[ColumnElement] = []
- for orig_col, ascending in orderby:
- col: Union[AdhocMetric, ColumnElement] = orig_col
- if isinstance(col, dict):
- col = cast(AdhocMetric, col)
- if col.get("sqlExpression"):
- col["sqlExpression"] = _process_sql_expression(
- expression=col["sqlExpression"],
- database_id=self.database_id,
- schema=self.schema,
- template_processor=template_processor,
- )
- if utils.is_adhoc_metric(col):
- # add adhoc sort by column to columns_by_name if not exists
- col = self.adhoc_metric_to_sqla(col, columns_by_name)
- # if the adhoc metric has been defined before
- # use the existing instance.
- col = metrics_exprs_by_expr.get(str(col), col)
- need_groupby = True
- elif col in columns_by_name:
- col = columns_by_name[col].get_sqla_col(
- template_processor=template_processor
- )
- elif col in metrics_exprs_by_label:
- col = metrics_exprs_by_label[col]
- need_groupby = True
- elif col in metrics_by_name:
- col = metrics_by_name[col].get_sqla_col(
- template_processor=template_processor
- )
- need_groupby = True
-
- if isinstance(col, ColumnElement):
- orderby_exprs.append(col)
- else:
- # Could not convert a column reference to valid ColumnElement
- raise QueryObjectValidationError(
- _("Unknown column used in orderby: %(col)s", col=orig_col)
- )
-
- select_exprs: List[Union[Column, Label]] = []
- groupby_all_columns = {}
- groupby_series_columns = {}
-
- # filter out the pseudo column __timestamp from columns
- columns = [col for col in columns if col != utils.DTTM_ALIAS]
- dttm_col = columns_by_name.get(granularity) if granularity else None
-
- if need_groupby:
- # dedup columns while preserving order
- columns = groupby or columns
- for selected in columns:
- if isinstance(selected, str):
- # if groupby field/expr equals granularity field/expr
- if selected == granularity:
- table_col = columns_by_name[selected]
- outer = table_col.get_timestamp_expression(
- time_grain=time_grain,
- label=selected,
- template_processor=template_processor,
- )
- # if groupby field equals a selected column
- elif selected in columns_by_name:
- outer = columns_by_name[selected].get_sqla_col(
- template_processor=template_processor
- )
- else:
- selected = validate_adhoc_subquery(
- selected,
- self.database_id,
- self.schema,
- )
- outer = literal_column(f"({selected})")
- outer = self.make_sqla_column_compatible(outer, selected)
- else:
- outer = self.adhoc_column_to_sqla(
- col=selected, template_processor=template_processor
- )
- groupby_all_columns[outer.name] = outer
- if (
- is_timeseries and not series_column_names
- ) or outer.name in series_column_names:
- groupby_series_columns[outer.name] = outer
- select_exprs.append(outer)
- elif columns:
- for selected in columns:
- if is_adhoc_column(selected):
- _sql = selected["sqlExpression"]
- _column_label = selected["label"]
- elif isinstance(selected, str):
- _sql = selected
- _column_label = selected
-
- selected = validate_adhoc_subquery(
- _sql,
- self.database_id,
- self.schema,
- )
- select_exprs.append(
- columns_by_name[selected].get_sqla_col(
- template_processor=template_processor
- )
- if isinstance(selected, str) and selected in columns_by_name
- else self.make_sqla_column_compatible(
- literal_column(selected), _column_label
- )
- )
- metrics_exprs = []
-
- if granularity:
- if granularity not in columns_by_name or not dttm_col:
- raise QueryObjectValidationError(
- _(
- 'Time column "%(col)s" does not exist in dataset',
- col=granularity,
- )
- )
- time_filters = []
-
- if is_timeseries:
- timestamp = dttm_col.get_timestamp_expression(
- time_grain=time_grain, template_processor=template_processor
- )
- # always put timestamp as the first column
- select_exprs.insert(0, timestamp)
- groupby_all_columns[timestamp.name] = timestamp
-
- # Use main dttm column to support index with secondary dttm columns.
- if (
- db_engine_spec.time_secondary_columns
- and self.main_dttm_col in self.dttm_cols
- and self.main_dttm_col != dttm_col.column_name
- ):
- time_filters.append(
- columns_by_name[self.main_dttm_col].get_time_filter(
- start_dttm=from_dttm,
- end_dttm=to_dttm,
- template_processor=template_processor,
- )
- )
- time_filters.append(
- dttm_col.get_time_filter(
- start_dttm=from_dttm,
- end_dttm=to_dttm,
- template_processor=template_processor,
- )
- )
-
- # Always remove duplicates by column name, as sometimes `metrics_exprs`
- # can have the same name as a groupby column (e.g. when users use
- # raw columns as custom SQL adhoc metric).
- select_exprs = remove_duplicates(
- select_exprs + metrics_exprs, key=lambda x: x.name
- )
-
- # Expected output columns
- labels_expected = [c.key for c in select_exprs]
-
- # Order by columns are "hidden" columns, some databases require them
- # always be present in SELECT if an aggregation function is used
- if not db_engine_spec.allows_hidden_ordeby_agg:
- select_exprs = remove_duplicates(select_exprs + orderby_exprs)
-
- qry = sa.select(select_exprs)
-
- tbl, cte = self.get_from_clause(template_processor)
-
- if groupby_all_columns:
- qry = qry.group_by(*groupby_all_columns.values())
-
- where_clause_and = []
- having_clause_and = []
-
- for flt in filter: # type: ignore
- if not all(flt.get(s) for s in ["col", "op"]):
- continue
- flt_col = flt["col"]
- val = flt.get("val")
- op = flt["op"].upper()
- col_obj: Optional[TableColumn] = None
- sqla_col: Optional[Column] = None
- if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
- col_obj = dttm_col
- elif is_adhoc_column(flt_col):
- sqla_col = self.adhoc_column_to_sqla(flt_col)
- else:
- col_obj = columns_by_name.get(flt_col)
- filter_grain = flt.get("grain")
-
- if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
- if get_column_name(flt_col) in removed_filters:
- # Skip generating SQLA filter when the jinja template handles it.
- continue
-
- if col_obj or sqla_col is not None:
- if sqla_col is not None:
- pass
- elif col_obj and filter_grain:
- sqla_col = col_obj.get_timestamp_expression(
- time_grain=filter_grain, template_processor=template_processor
- )
- elif col_obj:
- sqla_col = col_obj.get_sqla_col(
- template_processor=template_processor
- )
- col_type = col_obj.type if col_obj else None
- col_spec = db_engine_spec.get_column_spec(
- native_type=col_type,
- db_extra=self.database.get_extra(),
- )
- is_list_target = op in (
- utils.FilterOperator.IN.value,
- utils.FilterOperator.NOT_IN.value,
- )
-
- col_advanced_data_type = col_obj.advanced_data_type if col_obj else ""
-
- if col_spec and not col_advanced_data_type:
- target_generic_type = col_spec.generic_type
- else:
- target_generic_type = GenericDataType.STRING
- eq = self.filter_values_handler(
- values=val,
- operator=op,
- target_generic_type=target_generic_type,
- target_native_type=col_type,
- is_list_target=is_list_target,
- db_engine_spec=db_engine_spec,
- db_extra=self.database.get_extra(),
- )
- if (
- col_advanced_data_type != ""
- and feature_flag_manager.is_feature_enabled(
- "ENABLE_ADVANCED_DATA_TYPES"
- )
- and col_advanced_data_type in ADVANCED_DATA_TYPES
- ):
- values = eq if is_list_target else [eq] # type: ignore
- bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[
- col_advanced_data_type
- ].translate_type(
- {
- "type": col_advanced_data_type,
- "values": values,
- }
- )
- if bus_resp["error_message"]:
- raise AdvancedDataTypeResponseError(
- _(bus_resp["error_message"])
- )
-
- where_clause_and.append(
- ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter(
- sqla_col, op, bus_resp["values"]
- )
- )
- elif is_list_target:
- assert isinstance(eq, (tuple, list))
- if len(eq) == 0:
- raise QueryObjectValidationError(
- _("Filter value list cannot be empty")
- )
- if len(eq) > len(
- eq_without_none := [x for x in eq if x is not None]
- ):
- is_null_cond = sqla_col.is_(None)
- if eq:
- cond = or_(is_null_cond, sqla_col.in_(eq_without_none))
- else:
- cond = is_null_cond
- else:
- cond = sqla_col.in_(eq)
- if op == utils.FilterOperator.NOT_IN.value:
- cond = ~cond
- where_clause_and.append(cond)
- elif op == utils.FilterOperator.IS_NULL.value:
- where_clause_and.append(sqla_col.is_(None))
- elif op == utils.FilterOperator.IS_NOT_NULL.value:
- where_clause_and.append(sqla_col.isnot(None))
- elif op == utils.FilterOperator.IS_TRUE.value:
- where_clause_and.append(sqla_col.is_(True))
- elif op == utils.FilterOperator.IS_FALSE.value:
- where_clause_and.append(sqla_col.is_(False))
- else:
- if (
- op
- not in {
- utils.FilterOperator.EQUALS.value,
- utils.FilterOperator.NOT_EQUALS.value,
- }
- and eq is None
- ):
- raise QueryObjectValidationError(
- _(
- "Must specify a value for filters "
- "with comparison operators"
- )
- )
- if op == utils.FilterOperator.EQUALS.value:
- where_clause_and.append(sqla_col == eq)
- elif op == utils.FilterOperator.NOT_EQUALS.value:
- where_clause_and.append(sqla_col != eq)
- elif op == utils.FilterOperator.GREATER_THAN.value:
- where_clause_and.append(sqla_col > eq)
- elif op == utils.FilterOperator.LESS_THAN.value:
- where_clause_and.append(sqla_col < eq)
- elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value:
- where_clause_and.append(sqla_col >= eq)
- elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value:
- where_clause_and.append(sqla_col <= eq)
- elif op == utils.FilterOperator.LIKE.value:
- where_clause_and.append(sqla_col.like(eq))
- elif op == utils.FilterOperator.ILIKE.value:
- where_clause_and.append(sqla_col.ilike(eq))
- elif (
- op == utils.FilterOperator.TEMPORAL_RANGE.value
- and isinstance(eq, str)
- and col_obj is not None
- ):
- _since, _until = get_since_until_from_time_range(
- time_range=eq,
- time_shift=time_shift,
- extras=extras,
- )
- where_clause_and.append(
- col_obj.get_time_filter(
- start_dttm=_since,
- end_dttm=_until,
- label=sqla_col.key,
- template_processor=template_processor,
- )
- )
- else:
- raise QueryObjectValidationError(
- _("Invalid filter operation type: %(op)s", op=op)
- )
- where_clause_and += self.get_sqla_row_level_filters(template_processor)
- if extras:
- where = extras.get("where")
- if where:
- try:
- where = template_processor.process_template(f"({where})")
- except TemplateError as ex:
- raise QueryObjectValidationError(
- _(
- "Error in jinja expression in WHERE clause: %(msg)s",
- msg=ex.message,
- )
- ) from ex
- where = _process_sql_expression(
- expression=where,
- database_id=self.database_id,
- schema=self.schema,
- )
- where_clause_and += [self.text(where)]
- having = extras.get("having")
- if having:
- try:
- having = template_processor.process_template(f"({having})")
- except TemplateError as ex:
- raise QueryObjectValidationError(
- _(
- "Error in jinja expression in HAVING clause: %(msg)s",
- msg=ex.message,
- )
- ) from ex
- having = _process_sql_expression(
- expression=having,
- database_id=self.database_id,
- schema=self.schema,
- )
- having_clause_and += [self.text(having)]
-
- if apply_fetch_values_predicate and self.fetch_values_predicate:
- qry = qry.where(
- self.get_fetch_values_predicate(template_processor=template_processor)
- )
- if granularity:
- qry = qry.where(and_(*(time_filters + where_clause_and)))
- else:
- qry = qry.where(and_(*where_clause_and))
- qry = qry.having(and_(*having_clause_and))
-
- self.make_orderby_compatible(select_exprs, orderby_exprs)
-
- for col, (orig_col, ascending) in zip(orderby_exprs, orderby):
- if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label):
- # if engine does not allow using SELECT alias in ORDER BY
- # revert to the underlying column
- col = col.element
-
- if (
- db_engine_spec.allows_alias_in_select
- and db_engine_spec.allows_hidden_cc_in_orderby
- and col.name in [select_col.name for select_col in select_exprs]
- ):
- col = literal_column(col.name)
- direction = asc if ascending else desc
- qry = qry.order_by(direction(col))
-
- if row_limit:
- qry = qry.limit(row_limit)
- if row_offset:
- qry = qry.offset(row_offset)
-
- if series_limit and groupby_series_columns:
- if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries:
- # some sql dialects require for order by expressions
- # to also be in the select clause -- others, e.g. vertica,
- # require a unique inner alias
- inner_main_metric_expr = self.make_sqla_column_compatible(
- main_metric_expr, "mme_inner__"
- )
- inner_groupby_exprs = []
- inner_select_exprs = []
- for gby_name, gby_obj in groupby_series_columns.items():
- label = get_column_name(gby_name)
- inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__")
- inner_groupby_exprs.append(inner)
- inner_select_exprs.append(inner)
-
- inner_select_exprs += [inner_main_metric_expr]
- subq = select(inner_select_exprs).select_from(tbl)
- inner_time_filter = []
-
- if dttm_col and not db_engine_spec.time_groupby_inline:
- inner_time_filter = [
- dttm_col.get_time_filter(
- start_dttm=inner_from_dttm or from_dttm,
- end_dttm=inner_to_dttm or to_dttm,
- template_processor=template_processor,
- )
- ]
- subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
- subq = subq.group_by(*inner_groupby_exprs)
-
- ob = inner_main_metric_expr
- if series_limit_metric:
- ob = self._get_series_orderby(
- series_limit_metric=series_limit_metric,
- metrics_by_name=metrics_by_name,
- columns_by_name=columns_by_name,
- template_processor=template_processor,
- )
- direction = desc if order_desc else asc
- subq = subq.order_by(direction(ob))
- subq = subq.limit(series_limit)
-
- on_clause = []
- for gby_name, gby_obj in groupby_series_columns.items():
- # in this case the column name, not the alias, needs to be
- # conditionally mutated, as it refers to the column alias in
- # the inner query
- col_name = db_engine_spec.make_label_compatible(gby_name + "__")
- on_clause.append(gby_obj == column(col_name))
-
- tbl = tbl.join(subq.alias(), and_(*on_clause))
- else:
- if series_limit_metric:
- orderby = [
- (
- self._get_series_orderby(
- series_limit_metric=series_limit_metric,
- metrics_by_name=metrics_by_name,
- columns_by_name=columns_by_name,
- template_processor=template_processor,
- ),
- not order_desc,
- )
- ]
-
- # run prequery to get top groups
- prequery_obj = {
- "is_timeseries": False,
- "row_limit": series_limit,
- "metrics": metrics,
- "granularity": granularity,
- "groupby": groupby,
- "from_dttm": inner_from_dttm or from_dttm,
- "to_dttm": inner_to_dttm or to_dttm,
- "filter": filter,
- "orderby": orderby,
- "extras": extras,
- "columns": columns,
- "order_desc": True,
- }
-
- result = self.query(prequery_obj)
- prequeries.append(result.query)
- dimensions = [
- c
- for c in result.df.columns
- if c not in metrics and c in groupby_series_columns
- ]
- top_groups = self._get_top_groups(
- result.df, dimensions, groupby_series_columns, columns_by_name
- )
- qry = qry.where(top_groups)
-
- qry = qry.select_from(tbl)
-
- if is_rowcount:
- if not db_engine_spec.allows_subqueries:
- raise QueryObjectValidationError(
- _("Database does not support subqueries")
- )
- label = "rowcount"
- col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label)
- qry = select([col]).select_from(qry.alias("rowcount_qry"))
- labels_expected = [label]
-
- return SqlaQuery(
- applied_template_filters=applied_template_filters,
- cte=cte,
- extra_cache_keys=extra_cache_keys,
- labels_expected=labels_expected,
- sqla_query=qry,
- prequeries=prequeries,
- )
+ # def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
+ # self,
+ # apply_fetch_values_predicate: bool = False,
+ # columns: Optional[List[ColumnTyping]] = None,
+ # extras: Optional[Dict[str, Any]] = None,
+ # filter: Optional[ # pylint: disable=redefined-builtin
+ # List[QueryObjectFilterClause]
+ # ] = None,
+ # from_dttm: Optional[datetime] = None,
+ # granularity: Optional[str] = None,
+ # groupby: Optional[List[Column]] = None,
+ # inner_from_dttm: Optional[datetime] = None,
+ # inner_to_dttm: Optional[datetime] = None,
+ # is_rowcount: bool = False,
+ # is_timeseries: bool = True,
+ # metrics: Optional[List[Metric]] = None,
+ # orderby: Optional[List[OrderBy]] = None,
+ # order_desc: bool = True,
+ # to_dttm: Optional[datetime] = None,
+ # series_columns: Optional[List[Column]] = None,
+ # series_limit: Optional[int] = None,
+ # series_limit_metric: Optional[Metric] = None,
+ # row_limit: Optional[int] = None,
+ # row_offset: Optional[int] = None,
+ # timeseries_limit: Optional[int] = None,
+ # timeseries_limit_metric: Optional[Metric] = None,
+ # time_shift: Optional[str] = None,
+ # ) -> SqlaQuery:
+ # """Querying any sqla table from this common interface"""
+ # if granularity not in self.dttm_cols and granularity is not None:
+ # granularity = self.main_dttm_col
+
+ # extras = extras or {}
+ # time_grain = extras.get("time_grain_sqla")
+
+ # template_kwargs = {
+ # "columns": columns,
+ # "from_dttm": from_dttm.isoformat() if from_dttm else None,
+ # "groupby": groupby,
+ # "metrics": metrics,
+ # "row_limit": row_limit,
+ # "row_offset": row_offset,
+ # "time_column": granularity,
+ # "time_grain": time_grain,
+ # "to_dttm": to_dttm.isoformat() if to_dttm else None,
+ # "table_columns": [col.column_name for col in self.columns],
+ # "filter": filter,
+ # }
+ # columns = columns or []
+ # groupby = groupby or []
+ # series_column_names = utils.get_column_names(series_columns or [])
+ # # deprecated, to be removed in 2.0
+ # if is_timeseries and timeseries_limit:
+ # series_limit = timeseries_limit
+ # series_limit_metric = series_limit_metric or timeseries_limit_metric
+ # template_kwargs.update(self.template_params_dict)
+ # extra_cache_keys: List[Any] = []
+ # template_kwargs["extra_cache_keys"] = extra_cache_keys
+ # removed_filters: List[str] = []
+ # applied_template_filters: List[str] = []
+ # template_kwargs["removed_filters"] = removed_filters
+ # template_kwargs["applied_filters"] = applied_template_filters
+ # template_processor = self.get_template_processor(**template_kwargs)
+ # db_engine_spec = self.db_engine_spec
+ # prequeries: List[str] = []
+ # orderby = orderby or []
+ # need_groupby = bool(metrics is not None or groupby)
+ # metrics = metrics or []
+
+ # # For backward compatibility
+ # if granularity not in self.dttm_cols and granularity is not None:
+ # granularity = self.main_dttm_col
+
+ # columns_by_name: Dict[str, TableColumn] = {
+ # col.column_name: col for col in self.columns
+ # }
+
+ # metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
+
+ # if not granularity and is_timeseries:
+ # raise QueryObjectValidationError(
+ # _(
+ # "Datetime column not provided as part table configuration "
+ # "and is required by this type of chart"
+ # )
+ # )
+ # if not metrics and not columns and not groupby:
+ # raise QueryObjectValidationError(_("Empty query?"))
+
+ # metrics_exprs: List[ColumnElement] = []
+ # for metric in metrics:
+ # if utils.is_adhoc_metric(metric):
+ # assert isinstance(metric, dict)
+ # metrics_exprs.append(
+ # self.adhoc_metric_to_sqla(
+ # metric=metric,
+ # columns_by_name=columns_by_name,
+ # template_processor=template_processor,
+ # )
+ # )
+ # elif isinstance(metric, str) and metric in metrics_by_name:
+ # metrics_exprs.append(
+ # metrics_by_name[metric].get_sqla_col(
+ # template_processor=template_processor
+ # )
+ # )
+ # else:
+ # raise QueryObjectValidationError(
+ # _("Metric '%(metric)s' does not exist", metric=metric)
+ # )
+
+ # if metrics_exprs:
+ # main_metric_expr = metrics_exprs[0]
+ # else:
+ # main_metric_expr, label = literal_column("COUNT(*)"), "ccount"
+ # main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
+
+ # # To ensure correct handling of the ORDER BY labeling we need to reference the
+ # # metric instance if defined in the SELECT clause.
+ # # use the key of the ColumnClause for the expected label
+ # metrics_exprs_by_label = {m.key: m for m in metrics_exprs}
+ # metrics_exprs_by_expr = {str(m): m for m in metrics_exprs}
+
+ # # Since orderby may use adhoc metrics, too; we need to process them first
+ # orderby_exprs: List[ColumnElement] = []
+ # for orig_col, ascending in orderby:
+ # col: Union[AdhocMetric, ColumnElement] = orig_col
+ # if isinstance(col, dict):
+ # col = cast(AdhocMetric, col)
+ # if col.get("sqlExpression"):
+ # col["sqlExpression"] = _process_sql_expression(
+ # expression=col["sqlExpression"],
+ # database_id=self.database_id,
+ # schema=self.schema,
+ # template_processor=template_processor,
+ # )
+ # if utils.is_adhoc_metric(col):
+ # # add adhoc sort by column to columns_by_name if not exists
+ # col = self.adhoc_metric_to_sqla(col, columns_by_name)
+ # # if the adhoc metric has been defined before
+ # # use the existing instance.
+ # col = metrics_exprs_by_expr.get(str(col), col)
+ # need_groupby = True
+ # elif col in columns_by_name:
+ # col = columns_by_name[col].get_sqla_col(
+ # template_processor=template_processor
+ # )
+ # elif col in metrics_exprs_by_label:
+ # col = metrics_exprs_by_label[col]
+ # need_groupby = True
+ # elif col in metrics_by_name:
+ # col = metrics_by_name[col].get_sqla_col(
+ # template_processor=template_processor
+ # )
+ # need_groupby = True
+
+ # if isinstance(col, ColumnElement):
+ # orderby_exprs.append(col)
+ # else:
+ # # Could not convert a column reference to valid ColumnElement
+ # raise QueryObjectValidationError(
+ # _("Unknown column used in orderby: %(col)s", col=orig_col)
+ # )
+
+ # select_exprs: List[Union[Column, Label]] = []
+ # groupby_all_columns = {}
+ # groupby_series_columns = {}
+
+ # # filter out the pseudo column __timestamp from columns
+ # columns = [col for col in columns if col != utils.DTTM_ALIAS]
+ # dttm_col = columns_by_name.get(granularity) if granularity else None
+
+ # if need_groupby:
+ # # dedup columns while preserving order
+ # columns = groupby or columns
+ # for selected in columns:
+ # if isinstance(selected, str):
+ # # if groupby field/expr equals granularity field/expr
+ # if selected == granularity:
+ # table_col = columns_by_name[selected]
+ # outer = table_col.get_timestamp_expression(
+ # time_grain=time_grain,
+ # label=selected,
+ # template_processor=template_processor,
+ # )
+ # # if groupby field equals a selected column
+ # elif selected in columns_by_name:
+ # outer = columns_by_name[selected].get_sqla_col(
+ # template_processor=template_processor
+ # )
+ # else:
+ # selected = validate_adhoc_subquery(
+ # selected,
+ # self.database_id,
+ # self.schema,
+ # )
+ # outer = literal_column(f"({selected})")
+ # outer = self.make_sqla_column_compatible(outer, selected)
+ # else:
+ # outer = self.adhoc_column_to_sqla(
+ # col=selected, template_processor=template_processor
+ # )
+ # groupby_all_columns[outer.name] = outer
+ # if (
+ # is_timeseries and not series_column_names
+ # ) or outer.name in series_column_names:
+ # groupby_series_columns[outer.name] = outer
+ # select_exprs.append(outer)
+ # elif columns:
+ # for selected in columns:
+ # if is_adhoc_column(selected):
+ # _sql = selected["sqlExpression"]
+ # _column_label = selected["label"]
+ # elif isinstance(selected, str):
+ # _sql = selected
+ # _column_label = selected
+
+ # selected = validate_adhoc_subquery(
+ # _sql,
+ # self.database_id,
+ # self.schema,
+ # )
+ # select_exprs.append(
+ # columns_by_name[selected].get_sqla_col(
+ # template_processor=template_processor
+ # )
+ # if isinstance(selected, str) and selected in columns_by_name
+ # else self.make_sqla_column_compatible(
+ # literal_column(selected), _column_label
+ # )
+ # )
+ # metrics_exprs = []
+
+ # if granularity:
+ # if granularity not in columns_by_name or not dttm_col:
+ # raise QueryObjectValidationError(
+ # _(
+ # 'Time column "%(col)s" does not exist in dataset',
+ # col=granularity,
+ # )
+ # )
+ # time_filters = []
+
+ # if is_timeseries:
+ # timestamp = dttm_col.get_timestamp_expression(
+ # time_grain=time_grain, template_processor=template_processor
+ # )
+ # # always put timestamp as the first column
+ # select_exprs.insert(0, timestamp)
+ # groupby_all_columns[timestamp.name] = timestamp
+
+ # # Use main dttm column to support index with secondary dttm columns.
+ # if (
+ # db_engine_spec.time_secondary_columns
+ # and self.main_dttm_col in self.dttm_cols
+ # and self.main_dttm_col != dttm_col.column_name
+ # ):
+ # time_filters.append(
+ # columns_by_name[self.main_dttm_col].get_time_filter(
+ # start_dttm=from_dttm,
+ # end_dttm=to_dttm,
+ # template_processor=template_processor,
+ # )
+ # )
+ # time_filters.append(
+ # dttm_col.get_time_filter(
+ # start_dttm=from_dttm,
+ # end_dttm=to_dttm,
+ # template_processor=template_processor,
+ # )
+ # )
+
+ # # Always remove duplicates by column name, as sometimes `metrics_exprs`
+ # # can have the same name as a groupby column (e.g. when users use
+ # # raw columns as custom SQL adhoc metric).
+ # select_exprs = remove_duplicates(
+ # select_exprs + metrics_exprs, key=lambda x: x.name
+ # )
+
+ # # Expected output columns
+ # labels_expected = [c.key for c in select_exprs]
+
+ # # Order by columns are "hidden" columns, some databases require them
+ # # always be present in SELECT if an aggregation function is used
+ # if not db_engine_spec.allows_hidden_ordeby_agg:
+ # select_exprs = remove_duplicates(select_exprs + orderby_exprs)
+
+ # qry = sa.select(select_exprs)
+
+ # tbl, cte = self.get_from_clause(template_processor)
+
+ # if groupby_all_columns:
+ # qry = qry.group_by(*groupby_all_columns.values())
+
+ # where_clause_and = []
+ # having_clause_and = []
+
+ # for flt in filter: # type: ignore
+ # if not all(flt.get(s) for s in ["col", "op"]):
+ # continue
+ # flt_col = flt["col"]
+ # val = flt.get("val")
+ # op = flt["op"].upper()
+ # col_obj: Optional[TableColumn] = None
+ # sqla_col: Optional[Column] = None
+ # if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
+ # col_obj = dttm_col
+ # elif is_adhoc_column(flt_col):
+ # sqla_col = self.adhoc_column_to_sqla(flt_col)
+ # else:
+ # col_obj = columns_by_name.get(flt_col)
+ # filter_grain = flt.get("grain")
+
+ # if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
+ # if get_column_name(flt_col) in removed_filters:
+ # # Skip generating SQLA filter when the jinja template handles it.
+ # continue
+
+ # if col_obj or sqla_col is not None:
+ # if sqla_col is not None:
+ # pass
+ # elif col_obj and filter_grain:
+ # sqla_col = col_obj.get_timestamp_expression(
+ # time_grain=filter_grain, template_processor=template_processor
+ # )
+ # elif col_obj:
+ # sqla_col = col_obj.get_sqla_col(
+ # template_processor=template_processor
+ # )
+ # col_type = col_obj.type if col_obj else None
+ # col_spec = db_engine_spec.get_column_spec(
+ # native_type=col_type,
+ # db_extra=self.database.get_extra(),
+ # )
+ # is_list_target = op in (
+ # utils.FilterOperator.IN.value,
+ # utils.FilterOperator.NOT_IN.value,
+ # )
+
+ # col_advanced_data_type = col_obj.advanced_data_type if col_obj else ""
+
+ # if col_spec and not col_advanced_data_type:
+ # target_generic_type = col_spec.generic_type
+ # else:
+ # target_generic_type = GenericDataType.STRING
+ # eq = self.filter_values_handler(
+ # values=val,
+ # operator=op,
+ # target_generic_type=target_generic_type,
+ # target_native_type=col_type,
+ # is_list_target=is_list_target,
+ # db_engine_spec=db_engine_spec,
+ # db_extra=self.database.get_extra(),
+ # )
+ # if (
+ # col_advanced_data_type != ""
+ # and feature_flag_manager.is_feature_enabled(
+ # "ENABLE_ADVANCED_DATA_TYPES"
+ # )
+ # and col_advanced_data_type in ADVANCED_DATA_TYPES
+ # ):
+ # values = eq if is_list_target else [eq] # type: ignore
+ # bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[
+ # col_advanced_data_type
+ # ].translate_type(
+ # {
+ # "type": col_advanced_data_type,
+ # "values": values,
+ # }
+ # )
+ # if bus_resp["error_message"]:
+ # raise AdvancedDataTypeResponseError(
+ # _(bus_resp["error_message"])
+ # )
+
+ # where_clause_and.append(
+ # ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter(
+ # sqla_col, op, bus_resp["values"]
+ # )
+ # )
+ # elif is_list_target:
+ # assert isinstance(eq, (tuple, list))
+ # if len(eq) == 0:
+ # raise QueryObjectValidationError(
+ # _("Filter value list cannot be empty")
+ # )
+ # if len(eq) > len(
+ # eq_without_none := [x for x in eq if x is not None]
+ # ):
+ # is_null_cond = sqla_col.is_(None)
+ # if eq:
+ # cond = or_(is_null_cond, sqla_col.in_(eq_without_none))
+ # else:
+ # cond = is_null_cond
+ # else:
+ # cond = sqla_col.in_(eq)
+ # if op == utils.FilterOperator.NOT_IN.value:
+ # cond = ~cond
+ # where_clause_and.append(cond)
+ # elif op == utils.FilterOperator.IS_NULL.value:
+ # where_clause_and.append(sqla_col.is_(None))
+ # elif op == utils.FilterOperator.IS_NOT_NULL.value:
+ # where_clause_and.append(sqla_col.isnot(None))
+ # elif op == utils.FilterOperator.IS_TRUE.value:
+ # where_clause_and.append(sqla_col.is_(True))
+ # elif op == utils.FilterOperator.IS_FALSE.value:
+ # where_clause_and.append(sqla_col.is_(False))
+ # else:
+ # if (
+ # op
+ # not in {
+ # utils.FilterOperator.EQUALS.value,
+ # utils.FilterOperator.NOT_EQUALS.value,
+ # }
+ # and eq is None
+ # ):
+ # raise QueryObjectValidationError(
+ # _(
+ # "Must specify a value for filters "
+ # "with comparison operators"
+ # )
+ # )
+ # if op == utils.FilterOperator.EQUALS.value:
+ # where_clause_and.append(sqla_col == eq)
+ # elif op == utils.FilterOperator.NOT_EQUALS.value:
+ # where_clause_and.append(sqla_col != eq)
+ # elif op == utils.FilterOperator.GREATER_THAN.value:
+ # where_clause_and.append(sqla_col > eq)
+ # elif op == utils.FilterOperator.LESS_THAN.value:
+ # where_clause_and.append(sqla_col < eq)
+ # elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value:
+ # where_clause_and.append(sqla_col >= eq)
+ # elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value:
+ # where_clause_and.append(sqla_col <= eq)
+ # elif op == utils.FilterOperator.LIKE.value:
+ # where_clause_and.append(sqla_col.like(eq))
+ # elif op == utils.FilterOperator.ILIKE.value:
+ # where_clause_and.append(sqla_col.ilike(eq))
+ # elif (
+ # op == utils.FilterOperator.TEMPORAL_RANGE.value
+ # and isinstance(eq, str)
+ # and col_obj is not None
+ # ):
+ # _since, _until = get_since_until_from_time_range(
+ # time_range=eq,
+ # time_shift=time_shift,
+ # extras=extras,
+ # )
+ # where_clause_and.append(
+ # col_obj.get_time_filter(
+ # start_dttm=_since,
+ # end_dttm=_until,
+ # label=sqla_col.key,
+ # template_processor=template_processor,
+ # )
+ # )
+ # else:
+ # raise QueryObjectValidationError(
+ # _("Invalid filter operation type: %(op)s", op=op)
+ # )
+ # where_clause_and += self.get_sqla_row_level_filters(template_processor)
+ # if extras:
+ # where = extras.get("where")
+ # if where:
+ # try:
+ # where = template_processor.process_template(f"({where})")
+ # except TemplateError as ex:
+ # raise QueryObjectValidationError(
+ # _(
+ # "Error in jinja expression in WHERE clause: %(msg)s",
+ # msg=ex.message,
+ # )
+ # ) from ex
+ # where = _process_sql_expression(
+ # expression=where,
+ # database_id=self.database_id,
+ # schema=self.schema,
+ # )
+ # where_clause_and += [self.text(where)]
+ # having = extras.get("having")
+ # if having:
+ # try:
+ # having = template_processor.process_template(f"({having})")
+ # except TemplateError as ex:
+ # raise QueryObjectValidationError(
+ # _(
+ # "Error in jinja expression in HAVING clause: %(msg)s",
+ # msg=ex.message,
+ # )
+ # ) from ex
+ # having = _process_sql_expression(
+ # expression=having,
+ # database_id=self.database_id,
+ # schema=self.schema,
+ # )
+ # having_clause_and += [self.text(having)]
+
+ # if apply_fetch_values_predicate and self.fetch_values_predicate:
+ # qry = qry.where(
+ # self.get_fetch_values_predicate(template_processor=template_processor)
+ # )
+ # if granularity:
+ # qry = qry.where(and_(*(time_filters + where_clause_and)))
+ # else:
+ # qry = qry.where(and_(*where_clause_and))
+ # qry = qry.having(and_(*having_clause_and))
+
+ # self.make_orderby_compatible(select_exprs, orderby_exprs)
+
+ # for col, (orig_col, ascending) in zip(orderby_exprs, orderby):
+ # if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label):
+ # # if engine does not allow using SELECT alias in ORDER BY
+ # # revert to the underlying column
+ # col = col.element
+
+ # if (
+ # db_engine_spec.allows_alias_in_select
+ # and db_engine_spec.allows_hidden_cc_in_orderby
+ # and col.name in [select_col.name for select_col in select_exprs]
+ # ):
+ # col = literal_column(col.name)
+ # direction = asc if ascending else desc
+ # qry = qry.order_by(direction(col))
+
+ # if row_limit:
+ # qry = qry.limit(row_limit)
+ # if row_offset:
+ # qry = qry.offset(row_offset)
+
+ # if series_limit and groupby_series_columns:
+ # if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries:
+ # # some sql dialects require for order by expressions
+ # # to also be in the select clause -- others, e.g. vertica,
+ # # require a unique inner alias
+ # inner_main_metric_expr = self.make_sqla_column_compatible(
+ # main_metric_expr, "mme_inner__"
+ # )
+ # inner_groupby_exprs = []
+ # inner_select_exprs = []
+ # for gby_name, gby_obj in groupby_series_columns.items():
+ # label = get_column_name(gby_name)
+ # inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__")
+ # inner_groupby_exprs.append(inner)
+ # inner_select_exprs.append(inner)
+
+ # inner_select_exprs += [inner_main_metric_expr]
+ # subq = select(inner_select_exprs).select_from(tbl)
+ # inner_time_filter = []
+
+ # if dttm_col and not db_engine_spec.time_groupby_inline:
+ # inner_time_filter = [
+ # dttm_col.get_time_filter(
+ # start_dttm=inner_from_dttm or from_dttm,
+ # end_dttm=inner_to_dttm or to_dttm,
+ # template_processor=template_processor,
+ # )
+ # ]
+ # subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
+ # subq = subq.group_by(*inner_groupby_exprs)
+
+ # ob = inner_main_metric_expr
+ # if series_limit_metric:
+ # ob = self._get_series_orderby(
+ # series_limit_metric=series_limit_metric,
+ # metrics_by_name=metrics_by_name,
+ # columns_by_name=columns_by_name,
+ # template_processor=template_processor,
+ # )
+ # direction = desc if order_desc else asc
+ # subq = subq.order_by(direction(ob))
+ # subq = subq.limit(series_limit)
+
+ # on_clause = []
+ # for gby_name, gby_obj in groupby_series_columns.items():
+ # # in this case the column name, not the alias, needs to be
+ # # conditionally mutated, as it refers to the column alias in
+ # # the inner query
+ # col_name = db_engine_spec.make_label_compatible(gby_name + "__")
+ # on_clause.append(gby_obj == column(col_name))
+
+ # tbl = tbl.join(subq.alias(), and_(*on_clause))
+ # else:
+ # if series_limit_metric:
+ # orderby = [
+ # (
+ # self._get_series_orderby(
+ # series_limit_metric=series_limit_metric,
+ # metrics_by_name=metrics_by_name,
+ # columns_by_name=columns_by_name,
+ # template_processor=template_processor,
+ # ),
+ # not order_desc,
+ # )
+ # ]
+
+ # # run prequery to get top groups
+ # prequery_obj = {
+ # "is_timeseries": False,
+ # "row_limit": series_limit,
+ # "metrics": metrics,
+ # "granularity": granularity,
+ # "groupby": groupby,
+ # "from_dttm": inner_from_dttm or from_dttm,
+ # "to_dttm": inner_to_dttm or to_dttm,
+ # "filter": filter,
+ # "orderby": orderby,
+ # "extras": extras,
+ # "columns": columns,
+ # "order_desc": True,
+ # }
+
+ # result = self.query(prequery_obj)
+ # prequeries.append(result.query)
+ # dimensions = [
+ # c
+ # for c in result.df.columns
+ # if c not in metrics and c in groupby_series_columns
+ # ]
+ # top_groups = self._get_top_groups(
+ # result.df, dimensions, groupby_series_columns, columns_by_name
+ # )
+ # qry = qry.where(top_groups)
+
+ # qry = qry.select_from(tbl)
+
+ # if is_rowcount:
+ # if not db_engine_spec.allows_subqueries:
+ # raise QueryObjectValidationError(
+ # _("Database does not support subqueries")
+ # )
+ # label = "rowcount"
+ # col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label)
+ # qry = select([col]).select_from(qry.alias("rowcount_qry"))
+ # labels_expected = [label]
+
+ # return SqlaQuery(
+ # applied_template_filters=applied_template_filters,
+ # cte=cte,
+ # extra_cache_keys=extra_cache_keys,
+ # labels_expected=labels_expected,
+ # sqla_query=qry,
+ # prequeries=prequeries,
+ # )
def _get_series_orderby(
self,
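The helpers.py hunks below adjust ExploreMixin's abstract surface so it can stand in for the old SqlaTable implementation: query changes from a property into a method taking a query object, get_extra_cache_keys becomes an instance method returning hashables, and a new get_sqla_row_level_filters stub lets the shared get_sqla_query re-enable row-level security filters (replacing the earlier "todo(hugh)" workaround). Roughly, and only as a simplified sketch with stand-in Any types, the stubs a concrete model must override look like:

    from typing import Any, Dict, Hashable, List

    class ExploreMixin:
        # Stubs that a concrete datasource model (e.g. SqlaTable) overrides.

        def query(self, query_obj: Dict[str, Any]) -> Any:
            # Previously a property; now a method that executes a query
            # object and returns a result.
            raise NotImplementedError()

        def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]:
            raise NotImplementedError()

        def get_sqla_row_level_filters(self, template_processor: Any) -> List[Any]:
            # New in this commit: supplies row-level-security predicates that
            # get_sqla_query appends to the WHERE clause.
            raise NotImplementedError()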
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index fd0a1eff5c..26b07c6e54 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -26,8 +26,8 @@ from typing import (
Any,
cast,
Dict,
+ Hashable,
List,
- Mapping,
NamedTuple,
Optional,
Set,
@@ -87,7 +87,13 @@ from superset.superset_typing import (
QueryObjectDict,
)
from superset.utils import core as utils
-from superset.utils.core import get_user_id
+from superset.utils.core import (
+ GenericDataType,
+ get_column_name,
+ get_user_id,
+ is_adhoc_column,
+ remove_duplicates,
+)
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlMetric, TableColumn
@@ -680,7 +686,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
}
@property
- def query(self) -> str:
+ def fetch_value_predicate(self) -> str:
+ return "fix this!"
+
+ def query(self, query_obj: QueryObjectDict) -> QueryResult:
raise NotImplementedError()
@property
@@ -747,13 +756,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def get_fetch_values_predicate(self) -> List[Any]:
raise NotImplementedError()
- @staticmethod
- def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
+ def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]:
raise NotImplementedError()
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
raise NotImplementedError()
+ def get_sqla_row_level_filters(
+ self,
+ template_processor: BaseTemplateProcessor,
+ ) -> List[TextClause]:
+ raise NotImplementedError()
+
def _process_sql_expression( # pylint: disable=no-self-use
self,
expression: Optional[str],
@@ -1156,13 +1170,14 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def _get_series_orderby(
self,
series_limit_metric: Metric,
- metrics_by_name: Mapping[str, "SqlMetric"],
- columns_by_name: Mapping[str, "TableColumn"],
+ metrics_by_name: Dict[str, "SqlMetric"],
+ columns_by_name: Dict[str, "TableColumn"],
+ template_processor: Optional[BaseTemplateProcessor] = None,
) -> Column:
if utils.is_adhoc_metric(series_limit_metric):
assert isinstance(series_limit_metric, dict)
ob = self.adhoc_metric_to_sqla(
- series_limit_metric, columns_by_name # type: ignore
+ series_limit_metric, columns_by_name
)
elif (
isinstance(series_limit_metric, str)
@@ -1180,23 +1195,24 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
col: Type["AdhocColumn"], # type: ignore
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
- """
- Turn an adhoc column into a sqlalchemy column.
-
- :param col: Adhoc column definition
- :param template_processor: template_processor instance
- :returns: The metric defined as a sqlalchemy column
- :rtype: sqlalchemy.sql.column
- """
- label = utils.get_column_name(col) # type: ignore
- expression = self._process_sql_expression(
- expression=col["sqlExpression"],
- database_id=self.database_id,
- schema=self.schema,
- template_processor=template_processor,
- )
- sqla_column = literal_column(expression)
- return self.make_sqla_column_compatible(sqla_column, label)
+ raise NotImplementedError()
+ # """
+ # Turn an adhoc column into a sqlalchemy column.
+
+ # :param col: Adhoc column definition
+ # :param template_processor: template_processor instance
+ # :returns: The metric defined as a sqlalchemy column
+ # :rtype: sqlalchemy.sql.column
+ # """
+ # label = utils.get_column_name(col) # type: ignore
+ # expression = self._process_sql_expression(
+ # expression=col["sqlExpression"],
+ # database_id=self.database_id,
+ # schema=self.schema,
+ # template_processor=template_processor,
+ # )
+ # sqla_column = literal_column(expression)
+ # return self.make_sqla_column_compatible(sqla_column, label)
def _get_top_groups(
self,
@@ -1371,7 +1387,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"time_column": granularity,
"time_grain": time_grain,
"to_dttm": to_dttm.isoformat() if to_dttm else None,
- "table_columns": [col.get("column_name") for col in self.columns],
+ "table_columns": [col.column_name for col in self.columns],
"filter": filter,
}
columns = columns or []
@@ -1399,11 +1415,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if granularity not in self.dttm_cols and granularity is not None:
granularity = self.main_dttm_col
- columns_by_name: Dict[str, "TableColumn"] = {
- col.get("column_name"): col
- for col in self.columns # col.column_name: col for col in self.columns
+ columns_by_name: Dict[str, TableColumn] = {
+ col.column_name: col for col in self.columns
}
+ metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
+
if not granularity and is_timeseries:
raise QueryObjectValidationError(
_(
@@ -1425,6 +1442,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
template_processor=template_processor,
)
)
+ elif isinstance(metric, str) and metric in metrics_by_name:
+ metrics_exprs.append(
+ metrics_by_name[metric].get_sqla_col(
+ template_processor=template_processor
+ )
+ )
else:
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=metric)
@@ -1463,14 +1486,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
col = metrics_exprs_by_expr.get(str(col), col)
need_groupby = True
elif col in columns_by_name:
- gb_column_obj = columns_by_name[col]
- if isinstance(gb_column_obj, dict):
- col = self.get_sqla_col(gb_column_obj)
- else:
- col = gb_column_obj.get_sqla_col()
+ col = columns_by_name[col].get_sqla_col(
+ template_processor=template_processor
+ )
elif col in metrics_exprs_by_label:
col = metrics_exprs_by_label[col]
need_groupby = True
+ elif col in metrics_by_name:
+ col = metrics_by_name[col].get_sqla_col(
+ template_processor=template_processor
+ )
+ need_groupby = True
if isinstance(col, ColumnElement):
orderby_exprs.append(col)
@@ -1496,33 +1522,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
# if groupby field/expr equals granularity field/expr
if selected == granularity:
table_col = columns_by_name[selected]
- if isinstance(table_col, dict):
- outer = self.get_timestamp_expression(
- column=table_col,
- time_grain=time_grain,
- label=selected,
- template_processor=template_processor,
- )
- else:
- outer = table_col.get_timestamp_expression(
- time_grain=time_grain,
- label=selected,
- template_processor=template_processor,
- )
+ outer = table_col.get_timestamp_expression(
+ time_grain=time_grain,
+ label=selected,
+ template_processor=template_processor,
+ )
# if groupby field equals a selected column
elif selected in columns_by_name:
- if isinstance(columns_by_name[selected], dict):
- outer = sa.column(f"{selected}")
- outer = self.make_sqla_column_compatible(outer, selected)
- else:
- outer = columns_by_name[selected].get_sqla_col()
+ outer = columns_by_name[selected].get_sqla_col(
+ template_processor=template_processor
+ )
else:
- selected = self.validate_adhoc_subquery(
+ selected = validate_adhoc_subquery(
selected,
self.database_id,
self.schema,
)
- outer = sa.column(f"{selected}")
+ outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
else:
outer = self.adhoc_column_to_sqla(
@@ -1536,19 +1552,27 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
select_exprs.append(outer)
elif columns:
for selected in columns:
- selected = self.validate_adhoc_subquery(
- selected,
+ if is_adhoc_column(selected):
+ _sql = selected["sqlExpression"]
+ _column_label = selected["label"]
+ elif isinstance(selected, str):
+ _sql = selected
+ _column_label = selected
+
+ selected = validate_adhoc_subquery(
+ _sql,
self.database_id,
self.schema,
)
- if isinstance(columns_by_name[selected], dict):
- select_exprs.append(sa.column(f"{selected}"))
- else:
- select_exprs.append(
- columns_by_name[selected].get_sqla_col()
- if selected in columns_by_name
- else self.make_sqla_column_compatible(literal_column(selected))
+ select_exprs.append(
+ columns_by_name[selected].get_sqla_col(
+ template_processor=template_processor
)
+ if isinstance(selected, str) and selected in columns_by_name
+ else self.make_sqla_column_compatible(
+ literal_column(selected), _column_label
+ )
+ )
metrics_exprs = []
if granularity:
@@ -1559,57 +1583,41 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
col=granularity,
)
)
- time_filters: List[Any] = []
+ time_filters = []
if is_timeseries:
- if isinstance(dttm_col, dict):
- timestamp = self.get_timestamp_expression(
- dttm_col, time_grain, template_processor=template_processor
- )
- else:
- timestamp = dttm_col.get_timestamp_expression(
- time_grain=time_grain, template_processor=template_processor
- )
+ timestamp = dttm_col.get_timestamp_expression(
+ time_grain=time_grain, template_processor=template_processor
+ )
# always put timestamp as the first column
select_exprs.insert(0, timestamp)
groupby_all_columns[timestamp.name] = timestamp
# Use main dttm column to support index with secondary dttm columns.
- if db_engine_spec.time_secondary_columns:
- if isinstance(dttm_col, dict):
- dttm_col_name = dttm_col.get("column_name")
- else:
- dttm_col_name = dttm_col.column_name
-
- if (
- self.main_dttm_col in self.dttm_cols
- and self.main_dttm_col != dttm_col_name
- ):
- if isinstance(self.main_dttm_col, dict):
- time_filters.append(
- self.get_time_filter(
- self.main_dttm_col,
- from_dttm,
- to_dttm,
- )
- )
- else:
- time_filters.append(
- columns_by_name[self.main_dttm_col].get_time_filter(
- from_dttm,
- to_dttm,
- )
- )
-
- if isinstance(dttm_col, dict):
- time_filters.append(self.get_time_filter(dttm_col, from_dttm, to_dttm))
- else:
- time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
+ if (
+ db_engine_spec.time_secondary_columns
+ and self.main_dttm_col in self.dttm_cols
+ and self.main_dttm_col != dttm_col.column_name
+ ):
+ time_filters.append(
+ columns_by_name[self.main_dttm_col].get_time_filter(
+ start_dttm=from_dttm,
+ end_dttm=to_dttm,
+ template_processor=template_processor,
+ )
+ )
+ time_filters.append(
+ dttm_col.get_time_filter(
+ start_dttm=from_dttm,
+ end_dttm=to_dttm,
+ template_processor=template_processor,
+ )
+ )
# Always remove duplicates by column name, as sometimes `metrics_exprs`
# can have the same name as a groupby column (e.g. when users use
# raw columns as custom SQL adhoc metric).
- select_exprs = utils.remove_duplicates(
+ select_exprs = remove_duplicates(
select_exprs + metrics_exprs, key=lambda x: x.name
)
@@ -1619,7 +1627,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
# Order by columns are "hidden" columns, some databases require them
# always be present in SELECT if an aggregation function is used
if not db_engine_spec.allows_hidden_ordeby_agg:
- select_exprs = utils.remove_duplicates(select_exprs + orderby_exprs)
+ select_exprs = remove_duplicates(select_exprs + orderby_exprs)
qry = sa.select(select_exprs)
@@ -1637,18 +1645,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
flt_col = flt["col"]
val = flt.get("val")
op = flt["op"].upper()
- col_obj: Optional["TableColumn"] = None
+ col_obj: Optional[TableColumn] = None
sqla_col: Optional[Column] = None
if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
col_obj = dttm_col
- elif utils.is_adhoc_column(flt_col):
- sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore
+ elif is_adhoc_column(flt_col):
+ sqla_col = self.adhoc_column_to_sqla(col=flt_col, template_processor=template_processor) # type: ignore
else:
col_obj = columns_by_name.get(flt_col)
filter_grain = flt.get("grain")
if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
- if utils.get_column_name(flt_col) in removed_filters:
+ if get_column_name(flt_col) in removed_filters:
# Skip generating SQLA filter when the jinja template handles it.
continue
@@ -1656,44 +1664,29 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if sqla_col is not None:
pass
elif col_obj and filter_grain:
- if isinstance(col_obj, dict):
- sqla_col = self.get_timestamp_expression(
- col_obj, time_grain, template_processor=template_processor
- )
- else:
- sqla_col = col_obj.get_timestamp_expression(
- time_grain=filter_grain,
- template_processor=template_processor,
- )
- elif col_obj and isinstance(col_obj, dict):
- sqla_col = sa.column(col_obj.get("column_name"))
+ sqla_col = col_obj.get_timestamp_expression(
+ time_grain=filter_grain, template_processor=template_processor
+ )
elif col_obj:
- sqla_col = col_obj.get_sqla_col()
-
- if col_obj and isinstance(col_obj, dict):
- col_type = col_obj.get("type")
- else:
- col_type = col_obj.type if col_obj else None
+ sqla_col = col_obj.get_sqla_col(
+ template_processor=template_processor
+ )
+ col_type = col_obj.type if col_obj else None
col_spec = db_engine_spec.get_column_spec(
native_type=col_type,
- db_extra=self.database.get_extra(), # type: ignore
+# db_extra=self.database.get_extra(),
)
is_list_target = op in (
utils.FilterOperator.IN.value,
utils.FilterOperator.NOT_IN.value,
)
- if col_obj and isinstance(col_obj, dict):
- col_advanced_data_type = ""
- else:
- col_advanced_data_type = (
- col_obj.advanced_data_type if col_obj else ""
- )
+ col_advanced_data_type = col_obj.advanced_data_type if col_obj else ""
if col_spec and not col_advanced_data_type:
target_generic_type = col_spec.generic_type
else:
- target_generic_type = utils.GenericDataType.STRING
+ target_generic_type = GenericDataType.STRING
eq = self.filter_values_handler(
values=val,
operator=op,
@@ -1701,7 +1694,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
target_native_type=col_type,
is_list_target=is_list_target,
db_engine_spec=db_engine_spec,
- db_extra=self.database.get_extra(), # type: ignore
+ # db_extra=self.database.get_extra(),
)
if (
col_advanced_data_type != ""
@@ -1757,7 +1750,14 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
elif op == utils.FilterOperator.IS_FALSE.value:
where_clause_and.append(sqla_col.is_(False))
else:
- if eq is None:
+ if (
+ op
+ not in {
+ utils.FilterOperator.EQUALS.value,
+ utils.FilterOperator.NOT_EQUALS.value,
+ }
+ and eq is None
+ ):
raise QueryObjectValidationError(
_(
"Must specify a value for filters "
@@ -1791,23 +1791,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
extras=extras,
)
where_clause_and.append(
- self.get_time_filter(
- time_col=col_obj,
+ col_obj.get_time_filter(
start_dttm=_since,
end_dttm=_until,
+ label=sqla_col.key,
+ template_processor=template_processor,
)
)
else:
raise QueryObjectValidationError(
_("Invalid filter operation type: %(op)s", op=op)
)
- # todo(hugh): fix this w/ template_processor
- # where_clause_and += self.get_sqla_row_level_filters(template_processor)
+ where_clause_and += self.get_sqla_row_level_filters(template_processor)
if extras:
where = extras.get("where")
if where:
try:
- where = template_processor.process_template(f"{where}")
+ where = template_processor.process_template(f"({where})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
@@ -1815,11 +1815,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
msg=ex.message,
)
) from ex
+ where = self._process_sql_expression(
+ expression=where,
+ database_id=self.database_id,
+ schema=self.schema,
+ template_processor=template_processor,
+ )
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
try:
- having = template_processor.process_template(f"{having}")
+ having = template_processor.process_template(f"({having})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
@@ -1827,9 +1833,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
msg=ex.message,
)
) from ex
+ having = self._process_sql_expression(
+ expression=having,
+ database_id=self.database_id,
+ schema=self.schema,
+ template_processor=template_processor,
+ )
having_clause_and += [self.text(having)]
- if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore
- qry = qry.where(self.get_fetch_values_predicate()) # type: ignore
+
+ if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore
+ qry = qry.where(
+ self.get_fetch_values_predicate(template_processor=template_processor) # type: ignore
+ )
if granularity:
qry = qry.where(and_(*(time_filters + where_clause_and)))
else:
@@ -1869,7 +1884,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
inner_groupby_exprs = []
inner_select_exprs = []
for gby_name, gby_obj in groupby_series_columns.items():
- label = utils.get_column_name(gby_name)
+ label = get_column_name(gby_name)
inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__")
inner_groupby_exprs.append(inner)
inner_select_exprs.append(inner)
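
For the series-limit subquery, each groupby column is re-labeled with a "__" suffix so the outer query can later join back to the subquery on unambiguous names. Roughly, with plain SQLAlchemy columns standing in for Superset's column objects:

    import sqlalchemy as sa

    groupby_series_columns = {"country": sa.column("country")}

    inner_groupby_exprs = []
    inner_select_exprs = []
    for gby_name, gby_obj in groupby_series_columns.items():
        inner = gby_obj.label(gby_name + "__")  # country AS country__
        inner_groupby_exprs.append(inner)
        inner_select_exprs.append(inner)

    # the inner query selects and groups by the relabeled columns
    print(sa.select(*inner_select_exprs).group_by(*inner_groupby_exprs))
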
@@ -1879,26 +1894,24 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
inner_time_filter = []
if dttm_col and not db_engine_spec.time_groupby_inline:
- if isinstance(dttm_col, dict):
- inner_time_filter = [
- self.get_time_filter(
- dttm_col,
- inner_from_dttm or from_dttm,
- inner_to_dttm or to_dttm,
- )
- ]
- else:
- inner_time_filter = [
- dttm_col.get_time_filter(
- inner_from_dttm or from_dttm,
- inner_to_dttm or to_dttm,
- )
- ]
-
+ inner_time_filter = [
+ dttm_col.get_time_filter(
+ start_dttm=inner_from_dttm or from_dttm,
+ end_dttm=inner_to_dttm or to_dttm,
+ template_processor=template_processor,
+ )
+ ]
subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
subq = subq.group_by(*inner_groupby_exprs)
ob = inner_main_metric_expr
+ if series_limit_metric:
+ ob = self._get_series_orderby(
+ series_limit_metric=series_limit_metric,
+ metrics_by_name=metrics_by_name,
+ columns_by_name=columns_by_name,
+ template_processor=template_processor,
+ )
direction = sa.desc if order_desc else sa.asc
subq = subq.order_by(direction(ob))
subq = subq.limit(series_limit)
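
With the change above, the subquery ordering is derived from series_limit_metric when one is configured, instead of always falling back to the inner main metric. The effect, sketched with illustrative SQLAlchemy expressions: order the inner query by the series metric and keep only the top series_limit groups.

    import sqlalchemy as sa

    metric = sa.func.sum(sa.column("sales")).label("sum__sales")
    subq = sa.select(sa.column("country"), metric).group_by(sa.column("country"))

    order_desc, series_limit = True, 5
    direction = sa.desc if order_desc else sa.asc
    subq = subq.order_by(direction(metric)).limit(series_limit)
    print(subq)  # ... ORDER BY sum__sales DESC LIMIT :param_1
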
@@ -1912,6 +1925,19 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
on_clause.append(gby_obj == sa.column(col_name))
tbl = tbl.join(subq.alias(), and_(*on_clause))
+ else:
+ if series_limit_metric:
+ orderby = [
+ (
+ self._get_series_orderby(
+ series_limit_metric=series_limit_metric,
+ metrics_by_name=metrics_by_name,
+ columns_by_name=columns_by_name,
+ template_processor=template_processor,
+ ),
+ not order_desc,
+ )
+ ]
# run prequery to get top groups
prequery_obj = {
@@ -1928,7 +1954,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"columns": columns,
"order_desc": True,
}
- result = self.exc_query(prequery_obj)
+
+ result = self.query(prequery_obj)
prequeries.append(result.query)
dimensions = [
c
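
When there is no timeseries subquery, the same top-group restriction is done with a prequery: run the query once, ordered descending by the series metric, collect the top dimension values, and then constrain the main query to them; the switch from self.exc_query to self.query routes that prequery through the mixin's standard query entry point. A hedged sketch of the flow, where run_query stands in for self.query and the result is assumed to be QueryResult-like with a .df DataFrame:

    from typing import Any, Callable, Dict, List

    def top_groups(
        run_query: Callable[[Dict[str, Any]], Any],
        base_obj: Dict[str, Any],
        groupby: List[str],
    ) -> Dict[str, List[Any]]:
        # prequery: same query object, ordered descending by the series metric
        result = run_query({**base_obj, "order_desc": True})
        # keep the distinct values of each dimension that made the cut
        return {dim: result.df[dim].unique().tolist() for dim in groupby}
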
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index babea35baf..5ccba99975 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -33,6 +33,7 @@ from sqlalchemy import (
DateTime,
Enum,
ForeignKey,
+ Hashable,
Integer,
Numeric,
String,
@@ -307,7 +308,7 @@ class Query(
return ""
@staticmethod
- def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
+ def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]:
return []
@property
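
get_extra_cache_keys now returns List[Hashable] rather than List[str], since cache-key components only need to be hashable, not stringly typed. One caveat: Hashable is exported by typing (and collections.abc), not by SQLAlchemy, so the import belongs in the typing import list rather than inside "from sqlalchemy import (...)" as the hunk above places it. A minimal sketch:

    from typing import Any, Dict, Hashable, List

    def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]:
        # strings, numbers and tuples are all valid cache-key parts
        return [query_obj.get("granularity"), ("schema", "public")]

    print(tuple(get_extra_cache_keys({"granularity": "ds"})))
    # ('ds', ('schema', 'public'))
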
diff --git a/superset/result_set.py b/superset/result_set.py
index 3d29673b9f..63d48b1e4b 100644
--- a/superset/result_set.py
+++ b/superset/result_set.py
@@ -70,9 +70,9 @@ def stringify_values(array: NDArray[Any]) -> NDArray[Any]:
for obj in it:
if na_obj := pd.isna(obj):
# pandas <NA> type cannot be converted to string
- obj[na_obj] = None # type: ignore
+ obj[na_obj] = None
else:
- obj[...] = stringify(obj) # type: ignore
+ obj[...] = stringify(obj)
return result
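
The two "# type: ignore" removals above sit inside stringify_values, which walks an object ndarray in place, nulling pandas <NA> cells and stringifying everything else. A simplified, self-contained version of that loop (the real helper serializes through a JSON stringify function rather than str):

    import numpy as np
    import pandas as pd

    def stringify_values(array: np.ndarray) -> np.ndarray:
        result = np.copy(array)
        with np.nditer(result, flags=["refs_ok"], op_flags=["readwrite"]) as it:
            for obj in it:
                if na_obj := pd.isna(obj):
                    obj[na_obj] = None  # pandas <NA> cannot be converted to string
                else:
                    obj[...] = str(obj)
        return result

    print(stringify_values(np.array([1, pd.NA, {"a": 1}], dtype=object)))
    # ['1' None "{'a': 1}"]
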
diff --git a/superset/utils/pandas_postprocessing/boxplot.py b/superset/utils/pandas_postprocessing/boxplot.py
index 673c39ebf3..e2706345b1 100644
--- a/superset/utils/pandas_postprocessing/boxplot.py
+++ b/superset/utils/pandas_postprocessing/boxplot.py
@@ -57,10 +57,10 @@ def boxplot(
"""
def quartile1(series: Series) -> float:
- return np.nanpercentile(series, 25, interpolation="midpoint") # type: ignore
+ return np.nanpercentile(series, 25, interpolation="midpoint")
def quartile3(series: Series) -> float:
- return np.nanpercentile(series, 75, interpolation="midpoint") # type: ignore
+ return np.nanpercentile(series, 75, interpolation="midpoint")
if whisker_type == PostProcessingBoxplotWhiskerType.TUKEY:
@@ -99,8 +99,8 @@ def boxplot(
return np.nanpercentile(series, low)
else:
- whisker_high = np.max # type: ignore
- whisker_low = np.min # type: ignore
+ whisker_high = np.max
+ whisker_low = np.min
def outliers(series: Series) -> Set[float]:
above = series[series > whisker_high(series)]
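
The "# type: ignore" removals above cover the quartile and whisker helpers. The math: Q1/Q3 via midpoint interpolation, Tukey whiskers at 1.5 * IQR, and anything beyond them reported as an outlier. A small worked example (recent NumPy renamed the interpolation= keyword used in the patch to method=; the computation is unchanged):

    import numpy as np

    series = np.array([1.0, 2.0, 3.0, 4.0, 10.0, np.nan])

    q1 = np.nanpercentile(series, 25, method="midpoint")
    q3 = np.nanpercentile(series, 75, method="midpoint")
    iqr = q3 - q1
    whisker_high = q3 + 1.5 * iqr
    whisker_low = q1 - 1.5 * iqr

    outliers = series[(series > whisker_high) | (series < whisker_low)]
    print(q1, q3, outliers)  # 2.0 4.0 [10.]
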
diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py
index 1026164e45..db783c4bed 100644
--- a/superset/utils/pandas_postprocessing/flatten.py
+++ b/superset/utils/pandas_postprocessing/flatten.py
@@ -85,7 +85,7 @@ def flatten(
_columns = []
for series in df.columns.to_flat_index():
_cells = []
- for cell in series if is_sequence(series) else [series]: # type: ignore
+ for cell in series if is_sequence(series) else [series]:
if pd.notnull(cell):
# every cell should be converted to string and escape comma
_cells.append(escape_separator(str(cell)))
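
The flatten loop above joins MultiIndex column tuples into single comma-separated labels, escaping commas inside individual cells first. A condensed, runnable version, with simplified stand-ins for Superset's is_sequence and escape_separator helpers:

    import pandas as pd

    FLAT_COLUMN_SEPARATOR = ", "

    def is_sequence(column: object) -> bool:
        return isinstance(column, tuple)

    def escape_separator(value: str) -> str:
        return value.replace(",", "\\,")

    df = pd.DataFrame(
        [[1, 2]],
        columns=pd.MultiIndex.from_tuples([("sum", "sales"), ("avg", "cost")]),
    )

    flat_columns = []
    for series in df.columns.to_flat_index():
        cells = [
            escape_separator(str(cell))
            for cell in (series if is_sequence(series) else [series])
            if pd.notnull(cell)
        ]
        flat_columns.append(FLAT_COLUMN_SEPARATOR.join(cells))

    df.columns = flat_columns
    print(df.columns.tolist())  # ['sum, sales', 'avg, cost']
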