Posted to commits@superset.apache.org by hu...@apache.org on 2023/04/10 22:40:38 UTC
[superset] 07/07: chore: Refactor ExploreMixin to power both Datasets (SqlaTable) and Query models (#22853)
This is an automated email from the ASF dual-hosted git repository.
hugh pushed a commit to branch 2023.13.1
in repository https://gitbox.apache.org/repos/asf/superset.git
commit ba7b6fd7d9a5eb02ca997965f159e46685f7e2d4
Author: Hugh A. Miles II <hu...@gmail.com>
AuthorDate: Mon Apr 10 16:32:52 2023 -0400
chore: Refactor ExploreMixin to power both Datasets (SqlaTable) and Query models (#22853)
---
.../SqlLab/components/SaveDatasetModal/index.tsx | 6 +-
superset-frontend/src/SqlLab/fixtures.ts | 6 +-
.../controls/MetricControl/AdhocMetricOption.jsx | 2 +-
superset/connectors/sqla/models.py | 766 ++-------------------
superset/models/helpers.py | 487 ++++++++-----
superset/models/sql_lab.py | 65 +-
superset/utils/core.py | 12 +-
superset/views/core.py | 2 +-
tests/integration_tests/charts/data/api_tests.py | 4 +-
tests/integration_tests/sqllab_tests.py | 24 +-
10 files changed, 413 insertions(+), 961 deletions(-)
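For orientation before the diff: the refactor pulls the Explore query-building machinery out of SqlaTable into the shared ExploreMixin, declaring the model-specific pieces (columns, database extras, and so on) as overridable properties. A minimal standalone sketch of that pattern, with simplified names that are not part of the commit:

    from typing import Any, Dict, List, Optional

    class ExploreMixin:
        """Shared Explore query machinery; concrete models fill in the blanks."""

        @property
        def columns(self) -> List[Any]:
            raise NotImplementedError()

        @property
        def db_extra(self) -> Optional[Dict[str, Any]]:
            raise NotImplementedError()

        def get_sqla_query(self, **query_obj: Any) -> str:
            # the real implementation lives in superset/models/helpers.py below
            return f"SELECT ... built for {type(self).__name__}"

    class SqlaTable(ExploreMixin):  # dataset-backed Explore source
        @property
        def db_extra(self) -> Optional[Dict[str, Any]]:
            return {}  # the real model returns self.database.get_extra()

    class Query(ExploreMixin):  # SQL Lab query-backed Explore source
        @property
        def db_extra(self) -> Optional[Dict[str, Any]]:
            return None  # SQL Lab queries carry no database extras

    print(SqlaTable().get_sqla_query())  # SELECT ... built for SqlaTable
    print(Query().get_sqla_query())      # SELECT ... built for Query

Both SqlaTable (datasets) and Query (SQL Lab) mix this in below, so Explore can chart either.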
diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx
index 949323b9aa..402e26462e 100644
--- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx
+++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx
@@ -61,7 +61,7 @@ export type ExploreQuery = QueryResponse & {
};
export interface ISimpleColumn {
- name?: string | null;
+ column_name?: string | null;
type?: string | null;
is_dttm?: boolean | null;
}
@@ -216,7 +216,7 @@ export const SaveDatasetModal = ({
...formDataWithDefaults,
datasource: `${datasetToOverwrite.datasetid}__table`,
...(defaultVizType === 'table' && {
- all_columns: datasource?.columns?.map(column => column.name),
+ all_columns: datasource?.columns?.map(column => column.column_name),
}),
}),
]);
@@ -301,7 +301,7 @@ export const SaveDatasetModal = ({
...formDataWithDefaults,
datasource: `${data.table_id}__table`,
...(defaultVizType === 'table' && {
- all_columns: selectedColumns.map(column => column.name),
+ all_columns: selectedColumns.map(column => column.column_name),
}),
}),
)
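The ISimpleColumn rename above brings SQL Lab's column payloads in line with the backend TableColumn model, which keys columns by column_name. A Python sketch of the shape, assuming these three fields only (the real interface may carry more):

    from typing import List, Optional, TypedDict

    class SimpleColumn(TypedDict, total=False):
        column_name: Optional[str]  # renamed from `name` by this commit
        type: Optional[str]
        is_dttm: Optional[bool]

    def all_columns(columns: List[SimpleColumn]) -> List[Optional[str]]:
        # mirrors `columns.map(column => column.column_name)` in the modal above
        return [col.get("column_name") for col in columns]

    print(all_columns([{"column_name": "ds", "type": "TIMESTAMP", "is_dttm": True}]))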
diff --git a/superset-frontend/src/SqlLab/fixtures.ts b/superset-frontend/src/SqlLab/fixtures.ts
index fcb0fff8e3..ba88a41b0a 100644
--- a/superset-frontend/src/SqlLab/fixtures.ts
+++ b/superset-frontend/src/SqlLab/fixtures.ts
@@ -692,17 +692,17 @@ export const testQuery: ISaveableDatasource = {
sql: 'SELECT *',
columns: [
{
- name: 'Column 1',
+ column_name: 'Column 1',
type: DatasourceType.Query,
is_dttm: false,
},
{
- name: 'Column 3',
+ column_name: 'Column 3',
type: DatasourceType.Query,
is_dttm: false,
},
{
- name: 'Column 2',
+ column_name: 'Column 2',
type: DatasourceType.Query,
is_dttm: true,
},
diff --git a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx
index 80cf879f7f..c74212f0ba 100644
--- a/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx
+++ b/superset-frontend/src/explore/components/controls/MetricControl/AdhocMetricOption.jsx
@@ -48,7 +48,7 @@ class AdhocMetricOption extends React.PureComponent {
}
onRemoveMetric(e) {
- e.stopPropagation();
+ e?.stopPropagation();
this.props.onRemoveMetric(this.props.index);
}
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index b76e423caf..fd1276f592 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -31,7 +31,6 @@ from typing import (
Dict,
Hashable,
List,
- NamedTuple,
Optional,
Set,
Tuple,
@@ -50,11 +49,9 @@ from flask_babel import lazy_gettext as _
from jinja2.exceptions import TemplateError
from sqlalchemy import (
and_,
- asc,
Boolean,
Column,
DateTime,
- desc,
Enum,
ForeignKey,
inspect,
@@ -80,13 +77,11 @@ from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql.elements import ColumnClause, TextClause
-from sqlalchemy.sql.expression import Label, Select, TextAsFrom
+from sqlalchemy.sql.expression import Label, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause
from superset import app, db, is_feature_enabled, security_manager
-from superset.advanced_data_type.types import AdvancedDataTypeResponse
from superset.common.db_query_status import QueryStatus
-from superset.common.utils.time_range_utils import get_since_until_from_time_range
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.connectors.sqla.utils import (
find_cached_objects_in_session,
@@ -98,7 +93,6 @@ from superset.connectors.sqla.utils import (
from superset.datasets.models import Dataset as NewDataset
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.exceptions import (
- AdvancedDataTypeResponseError,
ColumnNotFoundException,
DatasetInvalidPermissionEvaluationException,
QueryClauseValidationException,
@@ -106,7 +100,6 @@ from superset.exceptions import (
SupersetGenericDBErrorException,
SupersetSecurityException,
)
-from superset.extensions import feature_flag_manager
from superset.jinja_context import (
BaseTemplateProcessor,
ExtraCache,
@@ -114,26 +107,17 @@ from superset.jinja_context import (
)
from superset.models.annotations import Annotation
from superset.models.core import Database
-from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult
-from superset.sql_parse import ParsedQuery, sanitize_clause
-from superset.superset_typing import (
- AdhocColumn,
- AdhocMetric,
- Column as ColumnTyping,
- Metric,
- OrderBy,
- QueryObjectDict,
+from superset.models.helpers import (
+ AuditMixinNullable,
+ CertificationMixin,
+ ExploreMixin,
+ QueryResult,
+ QueryStringExtended,
)
+from superset.sql_parse import ParsedQuery, sanitize_clause
+from superset.superset_typing import AdhocColumn, AdhocMetric, Metric, QueryObjectDict
from superset.utils import core as utils
-from superset.utils.core import (
- GenericDataType,
- get_column_name,
- get_username,
- is_adhoc_column,
- MediumText,
- QueryObjectFilterClause,
- remove_duplicates,
-)
+from superset.utils.core import GenericDataType, get_username, MediumText
config = app.config
metadata = Model.metadata # pylint: disable=no-member
@@ -150,26 +134,6 @@ ADDITIVE_METRIC_TYPES = {
ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}
-class SqlaQuery(NamedTuple):
- applied_template_filters: List[str]
- applied_filter_columns: List[ColumnTyping]
- rejected_filter_columns: List[ColumnTyping]
- cte: Optional[str]
- extra_cache_keys: List[Any]
- labels_expected: List[str]
- prequeries: List[str]
- sqla_query: Select
-
-
-class QueryStringExtended(NamedTuple):
- applied_template_filters: Optional[List[str]]
- applied_filter_columns: List[ColumnTyping]
- rejected_filter_columns: List[ColumnTyping]
- labels_expected: List[str]
- prequeries: List[str]
- sql: str
-
-
@dataclass
class MetadataResult:
added: List[str] = field(default_factory=list)
@@ -310,6 +274,35 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
def type_generic(self) -> Optional[utils.GenericDataType]:
if self.is_dttm:
return GenericDataType.TEMPORAL
+
+ bool_types = ("BOOL",)
+ num_types = (
+ "DOUBLE",
+ "FLOAT",
+ "INT",
+ "BIGINT",
+ "NUMBER",
+ "LONG",
+ "REAL",
+ "NUMERIC",
+ "DECIMAL",
+ "MONEY",
+ )
+ date_types = ("DATE", "TIME")
+ str_types = ("VARCHAR", "STRING", "CHAR")
+
+ if self.table is None:
+ # Query.TableColumns don't have a reference to a table.db_engine_spec
+ # reference so this logic will manage rendering types
+ if self.type and any(map(lambda t: t in self.type.upper(), str_types)):
+ return GenericDataType.STRING
+ if self.type and any(map(lambda t: t in self.type.upper(), bool_types)):
+ return GenericDataType.BOOLEAN
+ if self.type and any(map(lambda t: t in self.type.upper(), num_types)):
+ return GenericDataType.NUMERIC
+ if self.type and any(map(lambda t: t in self.type.upper(), date_types)):
+ return GenericDataType.TEMPORAL
+
column_spec = self.db_engine_spec.get_column_spec(
self.type, db_extra=self.db_extra
)
@@ -545,8 +538,10 @@ def _process_sql_expression(
return expression
-class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods
- """An ORM object for SqlAlchemy table references."""
+class SqlaTable(
+ Model, BaseDatasource, ExploreMixin
+): # pylint: disable=too-many-public-methods
+ """An ORM object for SqlAlchemy table references"""
type = "table"
query_language = "sql"
@@ -626,6 +621,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def __repr__(self) -> str: # pylint: disable=invalid-repr-returned
return self.name
+ @property
+ def db_extra(self) -> Dict[str, Any]:
+ return self.database.get_extra()
+
@staticmethod
def _apply_cte(sql: str, cte: Optional[str]) -> str:
"""
@@ -1009,6 +1008,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def adhoc_column_to_sqla(
self,
col: AdhocColumn,
+ force_type_check: bool = False,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
"""
@@ -1147,676 +1147,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def text(self, clause: str) -> TextClause:
return self.db_engine_spec.get_text_clause(clause)
- def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
- self,
- apply_fetch_values_predicate: bool = False,
- columns: Optional[List[ColumnTyping]] = None,
- extras: Optional[Dict[str, Any]] = None,
- filter: Optional[ # pylint: disable=redefined-builtin
- List[QueryObjectFilterClause]
- ] = None,
- from_dttm: Optional[datetime] = None,
- granularity: Optional[str] = None,
- groupby: Optional[List[Column]] = None,
- inner_from_dttm: Optional[datetime] = None,
- inner_to_dttm: Optional[datetime] = None,
- is_rowcount: bool = False,
- is_timeseries: bool = True,
- metrics: Optional[List[Metric]] = None,
- orderby: Optional[List[OrderBy]] = None,
- order_desc: bool = True,
- to_dttm: Optional[datetime] = None,
- series_columns: Optional[List[Column]] = None,
- series_limit: Optional[int] = None,
- series_limit_metric: Optional[Metric] = None,
- row_limit: Optional[int] = None,
- row_offset: Optional[int] = None,
- timeseries_limit: Optional[int] = None,
- timeseries_limit_metric: Optional[Metric] = None,
- time_shift: Optional[str] = None,
- ) -> SqlaQuery:
- """Querying any sqla table from this common interface"""
- if granularity not in self.dttm_cols and granularity is not None:
- granularity = self.main_dttm_col
-
- extras = extras or {}
- time_grain = extras.get("time_grain_sqla")
-
- template_kwargs = {
- "columns": columns,
- "from_dttm": from_dttm.isoformat() if from_dttm else None,
- "groupby": groupby,
- "metrics": metrics,
- "row_limit": row_limit,
- "row_offset": row_offset,
- "time_column": granularity,
- "time_grain": time_grain,
- "to_dttm": to_dttm.isoformat() if to_dttm else None,
- "table_columns": [col.column_name for col in self.columns],
- "filter": filter,
- }
- columns = columns or []
- groupby = groupby or []
- rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
- applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
- series_column_names = utils.get_column_names(series_columns or [])
- # deprecated, to be removed in 2.0
- if is_timeseries and timeseries_limit:
- series_limit = timeseries_limit
- series_limit_metric = series_limit_metric or timeseries_limit_metric
- template_kwargs.update(self.template_params_dict)
- extra_cache_keys: List[Any] = []
- template_kwargs["extra_cache_keys"] = extra_cache_keys
- removed_filters: List[str] = []
- applied_template_filters: List[str] = []
- template_kwargs["removed_filters"] = removed_filters
- template_kwargs["applied_filters"] = applied_template_filters
- template_processor = self.get_template_processor(**template_kwargs)
- db_engine_spec = self.db_engine_spec
- prequeries: List[str] = []
- orderby = orderby or []
- need_groupby = bool(metrics is not None or groupby)
- metrics = metrics or []
-
- # For backward compatibility
- if granularity not in self.dttm_cols and granularity is not None:
- granularity = self.main_dttm_col
-
- columns_by_name: Dict[str, TableColumn] = {
- col.column_name: col for col in self.columns
- }
-
- metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
-
- if not granularity and is_timeseries:
- raise QueryObjectValidationError(
- _(
- "Datetime column not provided as part table configuration "
- "and is required by this type of chart"
- )
- )
- if not metrics and not columns and not groupby:
- raise QueryObjectValidationError(_("Empty query?"))
-
- metrics_exprs: List[ColumnElement] = []
- for metric in metrics:
- if utils.is_adhoc_metric(metric):
- assert isinstance(metric, dict)
- metrics_exprs.append(
- self.adhoc_metric_to_sqla(
- metric=metric,
- columns_by_name=columns_by_name,
- template_processor=template_processor,
- )
- )
- elif isinstance(metric, str) and metric in metrics_by_name:
- metrics_exprs.append(
- metrics_by_name[metric].get_sqla_col(
- template_processor=template_processor
- )
- )
- else:
- raise QueryObjectValidationError(
- _("Metric '%(metric)s' does not exist", metric=metric)
- )
-
- if metrics_exprs:
- main_metric_expr = metrics_exprs[0]
- else:
- main_metric_expr, label = literal_column("COUNT(*)"), "count"
- main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
-
- # To ensure correct handling of the ORDER BY labeling we need to reference the
- # metric instance if defined in the SELECT clause.
- # use the key of the ColumnClause for the expected label
- metrics_exprs_by_label = {m.key: m for m in metrics_exprs}
- metrics_exprs_by_expr = {str(m): m for m in metrics_exprs}
-
- # Since orderby may use adhoc metrics, too; we need to process them first
- orderby_exprs: List[ColumnElement] = []
- for orig_col, ascending in orderby:
- col: Union[AdhocMetric, ColumnElement] = orig_col
- if isinstance(col, dict):
- col = cast(AdhocMetric, col)
- if col.get("sqlExpression"):
- col["sqlExpression"] = _process_sql_expression(
- expression=col["sqlExpression"],
- database_id=self.database_id,
- schema=self.schema,
- template_processor=template_processor,
- )
- if utils.is_adhoc_metric(col):
- # add adhoc sort by column to columns_by_name if not exists
- col = self.adhoc_metric_to_sqla(col, columns_by_name)
- # if the adhoc metric has been defined before
- # use the existing instance.
- col = metrics_exprs_by_expr.get(str(col), col)
- need_groupby = True
- elif col in columns_by_name:
- col = columns_by_name[col].get_sqla_col(
- template_processor=template_processor
- )
- elif col in metrics_exprs_by_label:
- col = metrics_exprs_by_label[col]
- need_groupby = True
- elif col in metrics_by_name:
- col = metrics_by_name[col].get_sqla_col(
- template_processor=template_processor
- )
- need_groupby = True
-
- if isinstance(col, ColumnElement):
- orderby_exprs.append(col)
- else:
- # Could not convert a column reference to valid ColumnElement
- raise QueryObjectValidationError(
- _("Unknown column used in orderby: %(col)s", col=orig_col)
- )
-
- select_exprs: List[Union[Column, Label]] = []
- groupby_all_columns = {}
- groupby_series_columns = {}
-
- # filter out the pseudo column __timestamp from columns
- columns = [col for col in columns if col != utils.DTTM_ALIAS]
- dttm_col = columns_by_name.get(granularity) if granularity else None
-
- if need_groupby:
- # dedup columns while preserving order
- columns = groupby or columns
- for selected in columns:
- if isinstance(selected, str):
- # if groupby field/expr equals granularity field/expr
- if selected == granularity:
- table_col = columns_by_name[selected]
- outer = table_col.get_timestamp_expression(
- time_grain=time_grain,
- label=selected,
- template_processor=template_processor,
- )
- # if groupby field equals a selected column
- elif selected in columns_by_name:
- outer = columns_by_name[selected].get_sqla_col(
- template_processor=template_processor
- )
- else:
- selected = validate_adhoc_subquery(
- selected,
- self.database_id,
- self.schema,
- )
- outer = literal_column(f"({selected})")
- outer = self.make_sqla_column_compatible(outer, selected)
- else:
- outer = self.adhoc_column_to_sqla(
- col=selected, template_processor=template_processor
- )
- groupby_all_columns[outer.name] = outer
- if (
- is_timeseries and not series_column_names
- ) or outer.name in series_column_names:
- groupby_series_columns[outer.name] = outer
- select_exprs.append(outer)
- elif columns:
- for selected in columns:
- if is_adhoc_column(selected):
- _sql = selected["sqlExpression"]
- _column_label = selected["label"]
- elif isinstance(selected, str):
- _sql = selected
- _column_label = selected
-
- selected = validate_adhoc_subquery(
- _sql,
- self.database_id,
- self.schema,
- )
- select_exprs.append(
- columns_by_name[selected].get_sqla_col(
- template_processor=template_processor
- )
- if isinstance(selected, str) and selected in columns_by_name
- else self.make_sqla_column_compatible(
- literal_column(selected), _column_label
- )
- )
- metrics_exprs = []
-
- if granularity:
- if granularity not in columns_by_name or not dttm_col:
- raise QueryObjectValidationError(
- _(
- 'Time column "%(col)s" does not exist in dataset',
- col=granularity,
- )
- )
- time_filters = []
-
- if is_timeseries:
- timestamp = dttm_col.get_timestamp_expression(
- time_grain=time_grain, template_processor=template_processor
- )
- # always put timestamp as the first column
- select_exprs.insert(0, timestamp)
- groupby_all_columns[timestamp.name] = timestamp
-
- # Use main dttm column to support index with secondary dttm columns.
- if (
- db_engine_spec.time_secondary_columns
- and self.main_dttm_col in self.dttm_cols
- and self.main_dttm_col != dttm_col.column_name
- ):
- time_filters.append(
- columns_by_name[self.main_dttm_col].get_time_filter(
- start_dttm=from_dttm,
- end_dttm=to_dttm,
- template_processor=template_processor,
- )
- )
- time_filters.append(
- dttm_col.get_time_filter(
- start_dttm=from_dttm,
- end_dttm=to_dttm,
- template_processor=template_processor,
- )
- )
-
- # Always remove duplicates by column name, as sometimes `metrics_exprs`
- # can have the same name as a groupby column (e.g. when users use
- # raw columns as custom SQL adhoc metric).
- select_exprs = remove_duplicates(
- select_exprs + metrics_exprs, key=lambda x: x.name
- )
-
- # Expected output columns
- labels_expected = [c.key for c in select_exprs]
-
- # Order by columns are "hidden" columns, some databases require them
- # always be present in SELECT if an aggregation function is used
- if not db_engine_spec.allows_hidden_orderby_agg:
- select_exprs = remove_duplicates(select_exprs + orderby_exprs)
-
- qry = sa.select(select_exprs)
-
- tbl, cte = self.get_from_clause(template_processor)
-
- if groupby_all_columns:
- qry = qry.group_by(*groupby_all_columns.values())
-
- where_clause_and = []
- having_clause_and = []
-
- for flt in filter: # type: ignore
- if not all(flt.get(s) for s in ["col", "op"]):
- continue
- flt_col = flt["col"]
- val = flt.get("val")
- op = flt["op"].upper()
- col_obj: Optional[TableColumn] = None
- sqla_col: Optional[Column] = None
- if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
- col_obj = dttm_col
- elif is_adhoc_column(flt_col):
- try:
- sqla_col = self.adhoc_column_to_sqla(flt_col)
- applied_adhoc_filters_columns.append(flt_col)
- except ColumnNotFoundException:
- rejected_adhoc_filters_columns.append(flt_col)
- continue
- else:
- col_obj = columns_by_name.get(cast(str, flt_col))
- filter_grain = flt.get("grain")
-
- if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
- if get_column_name(flt_col) in removed_filters:
- # Skip generating SQLA filter when the jinja template handles it.
- continue
-
- if col_obj or sqla_col is not None:
- if sqla_col is not None:
- pass
- elif col_obj and filter_grain:
- sqla_col = col_obj.get_timestamp_expression(
- time_grain=filter_grain, template_processor=template_processor
- )
- elif col_obj:
- sqla_col = col_obj.get_sqla_col(
- template_processor=template_processor
- )
- col_type = col_obj.type if col_obj else None
- col_spec = db_engine_spec.get_column_spec(
- native_type=col_type,
- db_extra=self.database.get_extra(),
- )
- is_list_target = op in (
- utils.FilterOperator.IN.value,
- utils.FilterOperator.NOT_IN.value,
- )
-
- col_advanced_data_type = col_obj.advanced_data_type if col_obj else ""
-
- if col_spec and not col_advanced_data_type:
- target_generic_type = col_spec.generic_type
- else:
- target_generic_type = GenericDataType.STRING
- eq = self.filter_values_handler(
- values=val,
- operator=op,
- target_generic_type=target_generic_type,
- target_native_type=col_type,
- is_list_target=is_list_target,
- db_engine_spec=db_engine_spec,
- db_extra=self.database.get_extra(),
- )
- if (
- col_advanced_data_type != ""
- and feature_flag_manager.is_feature_enabled(
- "ENABLE_ADVANCED_DATA_TYPES"
- )
- and col_advanced_data_type in ADVANCED_DATA_TYPES
- ):
- values = eq if is_list_target else [eq] # type: ignore
- bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[
- col_advanced_data_type
- ].translate_type(
- {
- "type": col_advanced_data_type,
- "values": values,
- }
- )
- if bus_resp["error_message"]:
- raise AdvancedDataTypeResponseError(
- _(bus_resp["error_message"])
- )
-
- where_clause_and.append(
- ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter(
- sqla_col, op, bus_resp["values"]
- )
- )
- elif is_list_target:
- assert isinstance(eq, (tuple, list))
- if len(eq) == 0:
- raise QueryObjectValidationError(
- _("Filter value list cannot be empty")
- )
- if len(eq) > len(
- eq_without_none := [x for x in eq if x is not None]
- ):
- is_null_cond = sqla_col.is_(None)
- if eq:
- cond = or_(is_null_cond, sqla_col.in_(eq_without_none))
- else:
- cond = is_null_cond
- else:
- cond = sqla_col.in_(eq)
- if op == utils.FilterOperator.NOT_IN.value:
- cond = ~cond
- where_clause_and.append(cond)
- elif op == utils.FilterOperator.IS_NULL.value:
- where_clause_and.append(sqla_col.is_(None))
- elif op == utils.FilterOperator.IS_NOT_NULL.value:
- where_clause_and.append(sqla_col.isnot(None))
- elif op == utils.FilterOperator.IS_TRUE.value:
- where_clause_and.append(sqla_col.is_(True))
- elif op == utils.FilterOperator.IS_FALSE.value:
- where_clause_and.append(sqla_col.is_(False))
- else:
- if (
- op
- not in {
- utils.FilterOperator.EQUALS.value,
- utils.FilterOperator.NOT_EQUALS.value,
- }
- and eq is None
- ):
- raise QueryObjectValidationError(
- _(
- "Must specify a value for filters "
- "with comparison operators"
- )
- )
- if op == utils.FilterOperator.EQUALS.value:
- where_clause_and.append(sqla_col == eq)
- elif op == utils.FilterOperator.NOT_EQUALS.value:
- where_clause_and.append(sqla_col != eq)
- elif op == utils.FilterOperator.GREATER_THAN.value:
- where_clause_and.append(sqla_col > eq)
- elif op == utils.FilterOperator.LESS_THAN.value:
- where_clause_and.append(sqla_col < eq)
- elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value:
- where_clause_and.append(sqla_col >= eq)
- elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value:
- where_clause_and.append(sqla_col <= eq)
- elif op == utils.FilterOperator.LIKE.value:
- where_clause_and.append(sqla_col.like(eq))
- elif op == utils.FilterOperator.ILIKE.value:
- where_clause_and.append(sqla_col.ilike(eq))
- elif (
- op == utils.FilterOperator.TEMPORAL_RANGE.value
- and isinstance(eq, str)
- and col_obj is not None
- ):
- _since, _until = get_since_until_from_time_range(
- time_range=eq,
- time_shift=time_shift,
- extras=extras,
- )
- where_clause_and.append(
- col_obj.get_time_filter(
- start_dttm=_since,
- end_dttm=_until,
- label=sqla_col.key,
- template_processor=template_processor,
- )
- )
- else:
- raise QueryObjectValidationError(
- _("Invalid filter operation type: %(op)s", op=op)
- )
- where_clause_and += self.get_sqla_row_level_filters(template_processor)
- if extras:
- where = extras.get("where")
- if where:
- try:
- where = template_processor.process_template(f"({where})")
- except TemplateError as ex:
- raise QueryObjectValidationError(
- _(
- "Error in jinja expression in WHERE clause: %(msg)s",
- msg=ex.message,
- )
- ) from ex
- where = _process_sql_expression(
- expression=where,
- database_id=self.database_id,
- schema=self.schema,
- )
- where_clause_and += [self.text(where)]
- having = extras.get("having")
- if having:
- try:
- having = template_processor.process_template(f"({having})")
- except TemplateError as ex:
- raise QueryObjectValidationError(
- _(
- "Error in jinja expression in HAVING clause: %(msg)s",
- msg=ex.message,
- )
- ) from ex
- having = _process_sql_expression(
- expression=having,
- database_id=self.database_id,
- schema=self.schema,
- )
- having_clause_and += [self.text(having)]
-
- if apply_fetch_values_predicate and self.fetch_values_predicate:
- qry = qry.where(
- self.get_fetch_values_predicate(template_processor=template_processor)
- )
- if granularity:
- qry = qry.where(and_(*(time_filters + where_clause_and)))
- else:
- qry = qry.where(and_(*where_clause_and))
- qry = qry.having(and_(*having_clause_and))
-
- self.make_orderby_compatible(select_exprs, orderby_exprs)
-
- for col, (orig_col, ascending) in zip(orderby_exprs, orderby):
- if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label):
- # if engine does not allow using SELECT alias in ORDER BY
- # revert to the underlying column
- col = col.element
-
- if (
- db_engine_spec.allows_alias_in_select
- and db_engine_spec.allows_hidden_cc_in_orderby
- and col.name in [select_col.name for select_col in select_exprs]
- ):
- col = literal_column(col.name)
- direction = asc if ascending else desc
- qry = qry.order_by(direction(col))
-
- if row_limit:
- qry = qry.limit(row_limit)
- if row_offset:
- qry = qry.offset(row_offset)
-
- if series_limit and groupby_series_columns:
- if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries:
- # some sql dialects require for order by expressions
- # to also be in the select clause -- others, e.g. vertica,
- # require a unique inner alias
- inner_main_metric_expr = self.make_sqla_column_compatible(
- main_metric_expr, "mme_inner__"
- )
- inner_groupby_exprs = []
- inner_select_exprs = []
- for gby_name, gby_obj in groupby_series_columns.items():
- label = get_column_name(gby_name)
- inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__")
- inner_groupby_exprs.append(inner)
- inner_select_exprs.append(inner)
-
- inner_select_exprs += [inner_main_metric_expr]
- subq = select(inner_select_exprs).select_from(tbl)
- inner_time_filter = []
-
- if dttm_col and not db_engine_spec.time_groupby_inline:
- inner_time_filter = [
- dttm_col.get_time_filter(
- start_dttm=inner_from_dttm or from_dttm,
- end_dttm=inner_to_dttm or to_dttm,
- template_processor=template_processor,
- )
- ]
- subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
- subq = subq.group_by(*inner_groupby_exprs)
-
- ob = inner_main_metric_expr
- if series_limit_metric:
- ob = self._get_series_orderby(
- series_limit_metric=series_limit_metric,
- metrics_by_name=metrics_by_name,
- columns_by_name=columns_by_name,
- template_processor=template_processor,
- )
- direction = desc if order_desc else asc
- subq = subq.order_by(direction(ob))
- subq = subq.limit(series_limit)
-
- on_clause = []
- for gby_name, gby_obj in groupby_series_columns.items():
- # in this case the column name, not the alias, needs to be
- # conditionally mutated, as it refers to the column alias in
- # the inner query
- col_name = db_engine_spec.make_label_compatible(gby_name + "__")
- on_clause.append(gby_obj == column(col_name))
-
- tbl = tbl.join(subq.alias(), and_(*on_clause))
- else:
- if series_limit_metric:
- orderby = [
- (
- self._get_series_orderby(
- series_limit_metric=series_limit_metric,
- metrics_by_name=metrics_by_name,
- columns_by_name=columns_by_name,
- template_processor=template_processor,
- ),
- not order_desc,
- )
- ]
-
- # run prequery to get top groups
- prequery_obj = {
- "is_timeseries": False,
- "row_limit": series_limit,
- "metrics": metrics,
- "granularity": granularity,
- "groupby": groupby,
- "from_dttm": inner_from_dttm or from_dttm,
- "to_dttm": inner_to_dttm or to_dttm,
- "filter": filter,
- "orderby": orderby,
- "extras": extras,
- "columns": columns,
- "order_desc": True,
- }
-
- result = self.query(prequery_obj)
- prequeries.append(result.query)
- dimensions = [
- c
- for c in result.df.columns
- if c not in metrics and c in groupby_series_columns
- ]
- top_groups = self._get_top_groups(
- result.df, dimensions, groupby_series_columns, columns_by_name
- )
- qry = qry.where(top_groups)
-
- qry = qry.select_from(tbl)
-
- if is_rowcount:
- if not db_engine_spec.allows_subqueries:
- raise QueryObjectValidationError(
- _("Database does not support subqueries")
- )
- label = "rowcount"
- col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label)
- qry = select([col]).select_from(qry.alias("rowcount_qry"))
- labels_expected = [label]
-
- filter_columns = [flt.get("col") for flt in filter] if filter else []
- rejected_filter_columns = [
- col
- for col in filter_columns
- if col
- and not is_adhoc_column(col)
- and col not in self.column_names
- and col not in applied_template_filters
- ] + rejected_adhoc_filters_columns
- applied_filter_columns = [
- col
- for col in filter_columns
- if col
- and not is_adhoc_column(col)
- and (col in self.column_names or col in applied_template_filters)
- ] + applied_adhoc_filters_columns
-
- return SqlaQuery(
- applied_template_filters=applied_template_filters,
- rejected_filter_columns=rejected_filter_columns,
- applied_filter_columns=applied_filter_columns,
- cte=cte,
- extra_cache_keys=extra_cache_keys,
- labels_expected=labels_expected,
- sqla_query=qry,
- prequeries=prequeries,
- )
-
def _get_series_orderby(
self,
series_limit_metric: Metric,
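The roughly 670 lines removed above are not dropped: get_sqla_query moves, largely verbatim, into ExploreMixin in superset/models/helpers.py (next file), and the SqlaQuery and QueryStringExtended result tuples move with it. For reference while reading that file, the extended result now carries two new bookkeeping fields; the shape below is copied from the diff, with ColumnTyping widened to Any so the sketch stands alone:

    from typing import Any, List, NamedTuple, Optional

    class QueryStringExtended(NamedTuple):
        applied_template_filters: Optional[List[str]]
        applied_filter_columns: List[Any]   # new in this commit
        rejected_filter_columns: List[Any]  # new in this commit
        labels_expected: List[str]
        prequeries: List[str]
        sql: str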
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 0c52465caa..ad76e0ed85 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -14,20 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""a collection of model-related helper classes and functions"""
# pylint: disable=too-many-lines
+"""a collection of model-related helper classes and functions"""
+import dataclasses
import json
import logging
import re
import uuid
+from collections import defaultdict
from datetime import datetime, timedelta
from json.decoder import JSONDecodeError
from typing import (
Any,
cast,
Dict,
+ Hashable,
List,
- Mapping,
NamedTuple,
Optional,
Set,
@@ -71,6 +73,7 @@ from superset.db_engine_specs.base import TimestampExpression
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
AdvancedDataTypeResponseError,
+ ColumnNotFoundException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetSecurityException,
@@ -88,7 +91,14 @@ from superset.superset_typing import (
QueryObjectDict,
)
from superset.utils import core as utils
-from superset.utils.core import get_user_id
+from superset.utils.core import (
+ GenericDataType,
+ get_column_name,
+ get_user_id,
+ is_adhoc_column,
+ remove_duplicates,
+)
+from superset.utils.dates import datetime_to_epoch
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlMetric, TableColumn
@@ -657,6 +667,8 @@ def clone_model(
# todo(hugh): centralize where this code lives
class QueryStringExtended(NamedTuple):
applied_template_filters: Optional[List[str]]
+ applied_filter_columns: List[ColumnTyping]
+ rejected_filter_columns: List[ColumnTyping]
labels_expected: List[str]
prequeries: List[str]
sql: str
@@ -664,6 +676,8 @@ class QueryStringExtended(NamedTuple):
class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
+ applied_filter_columns: List[ColumnTyping]
+ rejected_filter_columns: List[ColumnTyping]
cte: Optional[str]
extra_cache_keys: List[Any]
labels_expected: List[str]
@@ -687,7 +701,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
}
@property
- def query(self) -> str:
+ def fetch_value_predicate(self) -> str:
+ return "fix this!"
+
+ @property
+ def type(self) -> str:
+ raise NotImplementedError()
+
+ @property
+ def db_extra(self) -> Optional[Dict[str, Any]]:
+ raise NotImplementedError()
+
+ def query(self, query_obj: QueryObjectDict) -> QueryResult:
raise NotImplementedError()
@property
@@ -700,7 +725,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
@property
def metrics(self) -> List[Any]:
- raise NotImplementedError()
+ return []
@property
def uid(self) -> str:
@@ -750,17 +775,59 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def columns(self) -> List[Any]:
raise NotImplementedError()
- @property
- def get_fetch_values_predicate(self) -> List[Any]:
+ def get_fetch_values_predicate(
+ self, template_processor: Optional[BaseTemplateProcessor] = None
+ ) -> TextClause:
raise NotImplementedError()
- @staticmethod
- def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
+ def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]:
raise NotImplementedError()
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
raise NotImplementedError()
+ def get_sqla_row_level_filters(
+ self,
+ template_processor: BaseTemplateProcessor,
+ ) -> List[TextClause]:
+ """
+ Return the appropriate row level security filters for this table and the
+ current user. A custom username can be passed when the user is not present in the
+ Flask global namespace.
+
+ :param template_processor: The template processor to apply to the filters.
+ :returns: A list of SQL clauses to be ANDed together.
+ """
+ all_filters: List[TextClause] = []
+ filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
+ try:
+ for filter_ in security_manager.get_rls_filters(self):
+ clause = self.text(
+ f"({template_processor.process_template(filter_.clause)})"
+ )
+ if filter_.group_key:
+ filter_groups[filter_.group_key].append(clause)
+ else:
+ all_filters.append(clause)
+
+ if is_feature_enabled("EMBEDDED_SUPERSET"):
+ for rule in security_manager.get_guest_rls_filters(self):
+ clause = self.text(
+ f"({template_processor.process_template(rule['clause'])})"
+ )
+ all_filters.append(clause)
+
+ grouped_filters = [or_(*clauses) for clauses in filter_groups.values()]
+ all_filters.extend(grouped_filters)
+ return all_filters
+ except TemplateError as ex:
+ raise QueryObjectValidationError(
+ _(
+ "Error in jinja expression in RLS filters: %(msg)s",
+ msg=ex.message,
+ )
+ ) from ex
+
def _process_sql_expression( # pylint: disable=no-self-use
self,
expression: Optional[str],
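get_sqla_row_level_filters, added above, groups RLS clauses by group_key: clauses sharing a key are OR'ed together, and each group is then AND'ed with the ungrouped clauses by the caller. A toy sketch of those semantics over plain SQL strings (names are illustrative, not the Superset API):

    from collections import defaultdict
    from typing import Iterable, List, Optional, Tuple

    def combine_rls(filters: Iterable[Tuple[Optional[str], str]]) -> str:
        """OR clauses sharing a group_key; AND the groups with the rest."""
        ungrouped: List[str] = []
        groups = defaultdict(list)
        for group_key, clause in filters:
            if group_key:
                groups[group_key].append(clause)
            else:
                ungrouped.append(clause)
        ored = ["(" + " OR ".join(clauses) + ")" for clauses in groups.values()]
        return " AND ".join(ungrouped + ored)

    print(combine_rls([
        (None, "org_id = 1"),
        ("region", "region = 'EU'"),
        ("region", "region = 'US'"),
    ]))
    # org_id = 1 AND (region = 'EU' OR region = 'US')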
@@ -859,14 +926,19 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return ";\n".join(str(statement) for statement in statements)
- def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
+ def get_query_str_extended(
+ self, query_obj: QueryObjectDict, mutate: bool = True
+ ) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
- sql = self.mutate_query_from_config(sql)
+ if mutate:
+ sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
applied_template_filters=sqlaq.applied_template_filters,
+ applied_filter_columns=sqlaq.applied_filter_columns,
+ rejected_filter_columns=sqlaq.rejected_filter_columns,
labels_expected=sqlaq.labels_expected,
prequeries=sqlaq.prequeries,
sql=sql,
@@ -991,9 +1063,16 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
logger.warning(
"Query %s on schema %s failed", sql, self.schema, exc_info=True
)
+ db_engine_spec = self.db_engine_spec
+ errors = [
+ dataclasses.asdict(error) for error in db_engine_spec.extract_errors(ex)
+ ]
error_message = utils.error_msg_from_exception(ex)
return QueryResult(
+ applied_template_filters=query_str_ext.applied_template_filters,
+ applied_filter_columns=query_str_ext.applied_filter_columns,
+ rejected_filter_columns=query_str_ext.rejected_filter_columns,
status=status,
df=df,
duration=datetime.now() - qry_start_dttm,
@@ -1063,7 +1142,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,
- columns_by_name: Dict[str, "TableColumn"], # # pylint: disable=unused-argument
+ columns_by_name: Dict[str, "TableColumn"], # pylint: disable=unused-argument
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
"""
@@ -1163,19 +1242,20 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def _get_series_orderby(
self,
series_limit_metric: Metric,
- metrics_by_name: Mapping[str, "SqlMetric"],
- columns_by_name: Mapping[str, "TableColumn"],
+ metrics_by_name: Dict[str, "SqlMetric"],
+ columns_by_name: Dict[str, "TableColumn"],
+ template_processor: Optional[BaseTemplateProcessor] = None,
) -> Column:
if utils.is_adhoc_metric(series_limit_metric):
assert isinstance(series_limit_metric, dict)
- ob = self.adhoc_metric_to_sqla(
- series_limit_metric, columns_by_name # type: ignore
- )
+ ob = self.adhoc_metric_to_sqla(series_limit_metric, columns_by_name)
elif (
isinstance(series_limit_metric, str)
and series_limit_metric in metrics_by_name
):
- ob = metrics_by_name[series_limit_metric].get_sqla_col()
+ ob = metrics_by_name[series_limit_metric].get_sqla_col(
+ template_processor=template_processor
+ )
else:
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=series_limit_metric)
@@ -1184,26 +1264,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def adhoc_column_to_sqla(
self,
- col: Type["AdhocColumn"], # type: ignore
+ col: "AdhocColumn", # type: ignore
+ force_type_check: bool = False,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
- """
- Turn an adhoc column into a sqlalchemy column.
-
- :param col: Adhoc column definition
- :param template_processor: template_processor instance
- :returns: The metric defined as a sqlalchemy column
- :rtype: sqlalchemy.sql.column
- """
- label = utils.get_column_name(col) # type: ignore
- expression = self._process_sql_expression(
- expression=col["sqlExpression"],
- database_id=self.database_id,
- schema=self.schema,
- template_processor=template_processor,
- )
- sqla_column = literal_column(expression)
- return self.make_sqla_column_compatible(sqla_column, label)
+ raise NotImplementedError()
def _get_top_groups(
self,
@@ -1241,29 +1306,30 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return f'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}'
- def get_time_filter(
+ def get_time_filter( # pylint: disable=too-many-arguments
self,
- time_col: Dict[str, Any],
+ time_col: "TableColumn",
start_dttm: Optional[sa.DateTime],
end_dttm: Optional[sa.DateTime],
+ label: Optional[str] = "__time",
+ template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
- label = "__time"
- col = time_col.get("column_name")
- sqla_col = literal_column(col)
- my_col = self.make_sqla_column_compatible(sqla_col, label)
+ col = self.convert_tbl_column_to_sqla_col(
+ time_col, label=label, template_processor=template_processor
+ )
l = []
if start_dttm:
l.append(
- my_col
+ col
>= self.db_engine_spec.get_text_clause(
- self.dttm_sql_literal(start_dttm, time_col.get("type"))
+ self.dttm_sql_literal(start_dttm, time_col.type)
)
)
if end_dttm:
l.append(
- my_col
+ col
< self.db_engine_spec.get_text_clause(
- self.dttm_sql_literal(end_dttm, time_col.get("type"))
+ self.dttm_sql_literal(end_dttm, time_col.type)
)
)
return and_(*l)
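get_time_filter builds a half-open range: the lower bound is inclusive (>=) and the upper bound exclusive (<), so adjacent time windows tile without double-counting rows. A self-contained SQLAlchemy sketch of the clause shape (column and literals are illustrative):

    import sqlalchemy as sa

    ds = sa.column("ds")
    start, end = "2023-01-01", "2023-02-01"
    clauses = []
    if start:
        clauses.append(ds >= sa.text(f"'{start}'"))  # inclusive lower bound
    if end:
        clauses.append(ds < sa.text(f"'{end}'"))     # exclusive upper bound
    print(sa.and_(*clauses))  # ds >= '2023-01-01' AND ds < '2023-02-01'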
@@ -1327,11 +1393,24 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
time_expr = self.db_engine_spec.get_timestamp_expr(col, None, time_grain)
return self.make_sqla_column_compatible(time_expr, label)
- def get_sqla_col(self, col: Dict[str, Any]) -> Column:
- label = col.get("column_name")
- col_type = col.get("type")
- col = sa.column(label, type_=col_type)
- return self.make_sqla_column_compatible(col, label)
+ def convert_tbl_column_to_sqla_col(
+ self,
+ tbl_column: "TableColumn",
+ label: Optional[str] = None,
+ template_processor: Optional[BaseTemplateProcessor] = None,
+ ) -> Column:
+ label = label or tbl_column.column_name
+ db_engine_spec = self.db_engine_spec
+ column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra)
+ type_ = column_spec.sqla_type if column_spec else None
+ if expression := tbl_column.expression:
+ if template_processor:
+ expression = template_processor.process_template(expression)
+ col = literal_column(expression, type_=type_)
+ else:
+ col = sa.column(tbl_column.column_name, type_=type_)
+ col = self.make_sqla_column_compatible(col, label)
+ return col
def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
self,
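convert_tbl_column_to_sqla_col, added above, replaces the old dict-based get_sqla_col: an expression-backed (virtual) column becomes a literal_column, a physical one a plain sa.column, and both get a compatible label. A reduced sketch of the branch, dropping the engine-spec type lookup and Jinja templating the real method performs:

    import sqlalchemy as sa

    def to_sqla_col(column_name, expression=None, label=None):
        label = label or column_name
        if expression:                       # virtual column defined by SQL
            col = sa.literal_column(expression)
        else:                                # physical column on the table
            col = sa.column(column_name)
        return col.label(label)

    print(sa.select([to_sqla_col("num_girls")]))
    print(sa.select([to_sqla_col("revenue", "price * qty")]))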
@@ -1378,11 +1457,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"time_column": granularity,
"time_grain": time_grain,
"to_dttm": to_dttm.isoformat() if to_dttm else None,
- "table_columns": [col.get("column_name") for col in self.columns],
+ "table_columns": [col.column_name for col in self.columns],
"filter": filter,
}
columns = columns or []
groupby = groupby or []
+ rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
+ applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
series_column_names = utils.get_column_names(series_columns or [])
# deprecated, to be removed in 2.0
if is_timeseries and timeseries_limit:
@@ -1407,8 +1488,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
granularity = self.main_dttm_col
columns_by_name: Dict[str, "TableColumn"] = {
- col.get("column_name"): col
- for col in self.columns # col.column_name: col for col in self.columns
+ col.column_name: col for col in self.columns
+ }
+
+ metrics_by_name: Dict[str, "SqlMetric"] = {
+ m.metric_name: m for m in self.metrics
}
if not granularity and is_timeseries:
@@ -1432,6 +1516,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
template_processor=template_processor,
)
)
+ elif isinstance(metric, str) and metric in metrics_by_name:
+ metrics_exprs.append(
+ metrics_by_name[metric].get_sqla_col(
+ template_processor=template_processor
+ )
+ )
else:
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=metric)
@@ -1470,14 +1560,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
col = metrics_exprs_by_expr.get(str(col), col)
need_groupby = True
elif col in columns_by_name:
- gb_column_obj = columns_by_name[col]
- if isinstance(gb_column_obj, dict):
- col = self.get_sqla_col(gb_column_obj)
- else:
- col = gb_column_obj.get_sqla_col()
+ col = self.convert_tbl_column_to_sqla_col(
+ columns_by_name[col], template_processor=template_processor
+ )
elif col in metrics_exprs_by_label:
col = metrics_exprs_by_label[col]
need_groupby = True
+ elif col in metrics_by_name:
+ col = metrics_by_name[col].get_sqla_col(
+ template_processor=template_processor
+ )
+ need_groupby = True
if isinstance(col, ColumnElement):
orderby_exprs.append(col)
@@ -1503,33 +1596,24 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
# if groupby field/expr equals granularity field/expr
if selected == granularity:
table_col = columns_by_name[selected]
- if isinstance(table_col, dict):
- outer = self.get_timestamp_expression(
- column=table_col,
- time_grain=time_grain,
- label=selected,
- template_processor=template_processor,
- )
- else:
- outer = table_col.get_timestamp_expression(
- time_grain=time_grain,
- label=selected,
- template_processor=template_processor,
- )
+ outer = table_col.get_timestamp_expression(
+ time_grain=time_grain,
+ label=selected,
+ template_processor=template_processor,
+ )
# if groupby field equals a selected column
elif selected in columns_by_name:
- if isinstance(columns_by_name[selected], dict):
- outer = sa.column(f"{selected}")
- outer = self.make_sqla_column_compatible(outer, selected)
- else:
- outer = columns_by_name[selected].get_sqla_col()
+ outer = self.convert_tbl_column_to_sqla_col(
+ columns_by_name[selected],
+ template_processor=template_processor,
+ )
else:
- selected = self.validate_adhoc_subquery(
+ selected = validate_adhoc_subquery(
selected,
self.database_id,
self.schema,
)
- outer = sa.column(f"{selected}")
+ outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
else:
outer = self.adhoc_column_to_sqla(
@@ -1543,19 +1627,28 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
select_exprs.append(outer)
elif columns:
for selected in columns:
- selected = self.validate_adhoc_subquery(
- selected,
+ if is_adhoc_column(selected):
+ _sql = selected["sqlExpression"]
+ _column_label = selected["label"]
+ elif isinstance(selected, str):
+ _sql = selected
+ _column_label = selected
+
+ selected = validate_adhoc_subquery(
+ _sql,
self.database_id,
self.schema,
)
- if isinstance(columns_by_name[selected], dict):
- select_exprs.append(sa.column(f"{selected}"))
- else:
- select_exprs.append(
- columns_by_name[selected].get_sqla_col()
- if selected in columns_by_name
- else self.make_sqla_column_compatible(literal_column(selected))
+
+ select_exprs.append(
+ self.convert_tbl_column_to_sqla_col(
+ columns_by_name[selected], template_processor=template_processor
)
+ if isinstance(selected, str) and selected in columns_by_name
+ else self.make_sqla_column_compatible(
+ literal_column(selected), _column_label
+ )
+ )
metrics_exprs = []
if granularity:
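The rewritten columns branch above first normalizes each selected item: adhoc columns arrive as dicts carrying their own SQL and label, plain columns as strings that name themselves. A toy version of that split; the real check is utils.is_adhoc_column, which inspects the dict shape:

    def split_selected(selected):
        """Return (sql, label) for either flavor of selected column."""
        if isinstance(selected, dict) and "sqlExpression" in selected:
            return selected["sqlExpression"], selected["label"]
        return selected, selected  # a plain string column names itself

    print(split_selected("ds"))
    print(split_selected({"sqlExpression": "price * qty", "label": "revenue"}))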
@@ -1566,57 +1659,43 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
col=granularity,
)
)
- time_filters: List[Any] = []
+ time_filters = []
if is_timeseries:
- if isinstance(dttm_col, dict):
- timestamp = self.get_timestamp_expression(
- dttm_col, time_grain, template_processor=template_processor
- )
- else:
- timestamp = dttm_col.get_timestamp_expression(
- time_grain=time_grain, template_processor=template_processor
- )
+ timestamp = dttm_col.get_timestamp_expression(
+ time_grain=time_grain, template_processor=template_processor
+ )
# always put timestamp as the first column
select_exprs.insert(0, timestamp)
groupby_all_columns[timestamp.name] = timestamp
# Use main dttm column to support index with secondary dttm columns.
- if db_engine_spec.time_secondary_columns:
- if isinstance(dttm_col, dict):
- dttm_col_name = dttm_col.get("column_name")
- else:
- dttm_col_name = dttm_col.column_name
-
- if (
- self.main_dttm_col in self.dttm_cols
- and self.main_dttm_col != dttm_col_name
- ):
- if isinstance(self.main_dttm_col, dict):
- time_filters.append(
- self.get_time_filter(
- self.main_dttm_col,
- from_dttm,
- to_dttm,
- )
- )
- else:
- time_filters.append(
- columns_by_name[self.main_dttm_col].get_time_filter(
- from_dttm,
- to_dttm,
- )
- )
+ if (
+ db_engine_spec.time_secondary_columns
+ and self.main_dttm_col in self.dttm_cols
+ and self.main_dttm_col != dttm_col.column_name
+ ):
+ time_filters.append(
+ self.get_time_filter(
+ time_col=columns_by_name[self.main_dttm_col],
+ start_dttm=from_dttm,
+ end_dttm=to_dttm,
+ template_processor=template_processor,
+ )
+ )
- if isinstance(dttm_col, dict):
- time_filters.append(self.get_time_filter(dttm_col, from_dttm, to_dttm))
- else:
- time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
+ time_filter_column = self.get_time_filter(
+ time_col=dttm_col,
+ start_dttm=from_dttm,
+ end_dttm=to_dttm,
+ template_processor=template_processor,
+ )
+ time_filters.append(time_filter_column)
# Always remove duplicates by column name, as sometimes `metrics_exprs`
# can have the same name as a groupby column (e.g. when users use
# raw columns as custom SQL adhoc metric).
- select_exprs = utils.remove_duplicates(
+ select_exprs = remove_duplicates(
select_exprs + metrics_exprs, key=lambda x: x.name
)
@@ -1626,7 +1705,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
# Order by columns are "hidden" columns, some databases require them
# always be present in SELECT if an aggregation function is used
if not db_engine_spec.allows_hidden_orderby_agg:
- select_exprs = utils.remove_duplicates(select_exprs + orderby_exprs)
+ select_exprs = remove_duplicates(select_exprs + orderby_exprs)
qry = sa.select(select_exprs)
@@ -1648,14 +1727,19 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
sqla_col: Optional[Column] = None
if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
col_obj = dttm_col
- elif utils.is_adhoc_column(flt_col):
- sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore
+ elif is_adhoc_column(flt_col):
+ try:
+ sqla_col = self.adhoc_column_to_sqla(flt_col, force_type_check=True)
+ applied_adhoc_filters_columns.append(flt_col)
+ except ColumnNotFoundException:
+ rejected_adhoc_filters_columns.append(flt_col)
+ continue
else:
col_obj = columns_by_name.get(cast(str, flt_col))
filter_grain = flt.get("grain")
if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
- if utils.get_column_name(flt_col) in removed_filters:
+ if get_column_name(flt_col) in removed_filters:
# Skip generating SQLA filter when the jinja template handles it.
continue
@@ -1663,44 +1747,29 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if sqla_col is not None:
pass
elif col_obj and filter_grain:
- if isinstance(col_obj, dict):
- sqla_col = self.get_timestamp_expression(
- col_obj, time_grain, template_processor=template_processor
- )
- else:
- sqla_col = col_obj.get_timestamp_expression(
- time_grain=filter_grain,
- template_processor=template_processor,
- )
- elif col_obj and isinstance(col_obj, dict):
- sqla_col = sa.column(col_obj.get("column_name"))
+ sqla_col = col_obj.get_timestamp_expression(
+ time_grain=filter_grain, template_processor=template_processor
+ )
elif col_obj:
- sqla_col = col_obj.get_sqla_col()
-
- if col_obj and isinstance(col_obj, dict):
- col_type = col_obj.get("type")
- else:
- col_type = col_obj.type if col_obj else None
+ sqla_col = self.convert_tbl_column_to_sqla_col(
+ tbl_column=col_obj, template_processor=template_processor
+ )
+ col_type = col_obj.type if col_obj else None
col_spec = db_engine_spec.get_column_spec(
native_type=col_type,
- db_extra=self.database.get_extra(), # type: ignore
+ # db_extra=self.database.get_extra(),
)
is_list_target = op in (
utils.FilterOperator.IN.value,
utils.FilterOperator.NOT_IN.value,
)
- if col_obj and isinstance(col_obj, dict):
- col_advanced_data_type = ""
- else:
- col_advanced_data_type = (
- col_obj.advanced_data_type if col_obj else ""
- )
+ col_advanced_data_type = col_obj.advanced_data_type if col_obj else ""
if col_spec and not col_advanced_data_type:
target_generic_type = col_spec.generic_type
else:
- target_generic_type = utils.GenericDataType.STRING
+ target_generic_type = GenericDataType.STRING
eq = self.filter_values_handler(
values=val,
operator=op,
@@ -1708,7 +1777,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
target_native_type=col_type,
is_list_target=is_list_target,
db_engine_spec=db_engine_spec,
- db_extra=self.database.get_extra(), # type: ignore
+ # db_extra=self.database.get_extra(),
)
if (
col_advanced_data_type != ""
@@ -1764,7 +1833,14 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
elif op == utils.FilterOperator.IS_FALSE.value:
where_clause_and.append(sqla_col.is_(False))
else:
- if eq is None:
+ if (
+ op
+ not in {
+ utils.FilterOperator.EQUALS.value,
+ utils.FilterOperator.NOT_EQUALS.value,
+ }
+ and eq is None
+ ):
raise QueryObjectValidationError(
_(
"Must specify a value for filters "
@@ -1802,19 +1878,20 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
time_col=col_obj,
start_dttm=_since,
end_dttm=_until,
+ label=sqla_col.key,
+ template_processor=template_processor,
)
)
else:
raise QueryObjectValidationError(
_("Invalid filter operation type: %(op)s", op=op)
)
- # todo(hugh): fix this w/ template_processor
- # where_clause_and += self.get_sqla_row_level_filters(template_processor)
+ where_clause_and += self.get_sqla_row_level_filters(template_processor)
if extras:
where = extras.get("where")
if where:
try:
- where = template_processor.process_template(f"{where}")
+ where = template_processor.process_template(f"({where})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
@@ -1822,11 +1899,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
msg=ex.message,
)
) from ex
+ where = self._process_sql_expression(
+ expression=where,
+ database_id=self.database_id,
+ schema=self.schema,
+ template_processor=template_processor,
+ )
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
try:
- having = template_processor.process_template(f"{having}")
+ having = template_processor.process_template(f"({having})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
@@ -1834,9 +1917,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
msg=ex.message,
)
) from ex
+ having = self._process_sql_expression(
+ expression=having,
+ database_id=self.database_id,
+ schema=self.schema,
+ template_processor=template_processor,
+ )
having_clause_and += [self.text(having)]
+
if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore
- qry = qry.where(self.get_fetch_values_predicate()) # type: ignore
+ qry = qry.where(
+ self.get_fetch_values_predicate(template_processor=template_processor)
+ )
if granularity:
qry = qry.where(and_(*(time_filters + where_clause_and)))
else:
@@ -1876,7 +1968,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
inner_groupby_exprs = []
inner_select_exprs = []
for gby_name, gby_obj in groupby_series_columns.items():
- label = utils.get_column_name(gby_name)
+ label = get_column_name(gby_name)
inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__")
inner_groupby_exprs.append(inner)
inner_select_exprs.append(inner)
@@ -1886,26 +1978,25 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
inner_time_filter = []
if dttm_col and not db_engine_spec.time_groupby_inline:
- if isinstance(dttm_col, dict):
- inner_time_filter = [
- self.get_time_filter(
- dttm_col,
- inner_from_dttm or from_dttm,
- inner_to_dttm or to_dttm,
- )
- ]
- else:
- inner_time_filter = [
- dttm_col.get_time_filter(
- inner_from_dttm or from_dttm,
- inner_to_dttm or to_dttm,
- )
- ]
-
+ inner_time_filter = [
+ self.get_time_filter(
+ time_col=dttm_col,
+ start_dttm=inner_from_dttm or from_dttm,
+ end_dttm=inner_to_dttm or to_dttm,
+ template_processor=template_processor,
+ )
+ ]
subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
subq = subq.group_by(*inner_groupby_exprs)
ob = inner_main_metric_expr
+ if series_limit_metric:
+ ob = self._get_series_orderby(
+ series_limit_metric=series_limit_metric,
+ metrics_by_name=metrics_by_name,
+ columns_by_name=columns_by_name,
+ template_processor=template_processor,
+ )
direction = sa.desc if order_desc else sa.asc
subq = subq.order_by(direction(ob))
subq = subq.limit(series_limit)
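The branch above implements series limits by building an inner top-N subquery (aliased group-by columns plus the ordering metric) and joining the outer query back to it. A self-contained sketch of the idea, using 1.x-style select() as the diff does and made-up table and column names:

    import sqlalchemy as sa

    tbl = sa.table("births", sa.column("country"), sa.column("num"))
    mme = sa.func.sum(tbl.c.num).label("mme_inner__")  # inner metric alias
    inner = (
        sa.select([tbl.c.country.label("country__"), mme])
        .group_by(tbl.c.country)
        .order_by(sa.desc(mme))
        .limit(5)
        .alias("top_series")
    )
    # the outer query only sees the top five series found by the inner query
    outer = (
        sa.select([tbl.c.country, sa.func.sum(tbl.c.num)])
        .select_from(tbl.join(inner, tbl.c.country == inner.c.country__))
        .group_by(tbl.c.country)
    )
    print(outer)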
@@ -1919,6 +2010,19 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
on_clause.append(gby_obj == sa.column(col_name))
tbl = tbl.join(subq.alias(), and_(*on_clause))
+ else:
+ if series_limit_metric:
+ orderby = [
+ (
+ self._get_series_orderby(
+ series_limit_metric=series_limit_metric,
+ metrics_by_name=metrics_by_name,
+ columns_by_name=columns_by_name,
+ template_processor=template_processor,
+ ),
+ not order_desc,
+ )
+ ]
# run prequery to get top groups
prequery_obj = {
@@ -1935,7 +2039,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"columns": columns,
"order_desc": True,
}
- result = self.exc_query(prequery_obj)
+
+ result = self.query(prequery_obj)
prequeries.append(result.query)
dimensions = [
c
@@ -1959,9 +2064,29 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
qry = sa.select([col]).select_from(qry.alias("rowcount_qry"))
labels_expected = [label]
+ filter_columns = [flt.get("col") for flt in filter] if filter else []
+ rejected_filter_columns = [
+ col
+ for col in filter_columns
+ if col
+ and not is_adhoc_column(col)
+ and col not in self.column_names
+ and col not in applied_template_filters
+ ] + rejected_adhoc_filters_columns
+
+ applied_filter_columns = [
+ col
+ for col in filter_columns
+ if col
+ and not is_adhoc_column(col)
+ and (col in self.column_names or col in applied_template_filters)
+ ] + applied_adhoc_filters_columns
+
return SqlaQuery(
applied_template_filters=applied_template_filters,
cte=cte,
+ applied_filter_columns=applied_filter_columns,
+ rejected_filter_columns=rejected_filter_columns,
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
sqla_query=qry,
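The new applied_filter_columns / rejected_filter_columns bookkeeping lets chart consumers flag filters that reference unknown columns instead of failing silently; adhoc filter columns are tallied separately upstream. A condensed, self-contained restatement of the partitioning, with is_adhoc_column approximated as a dict check:

    from typing import Iterable, List, Set, Tuple

    def partition_filter_columns(
        filter_columns: Iterable,
        known_columns: Set[str],
        template_applied: Set[str],
    ) -> Tuple[List[str], List[str]]:
        applied, rejected = [], []
        for col in filter_columns:
            if not col or isinstance(col, dict):  # adhoc: handled elsewhere
                continue
            if col in known_columns or col in template_applied:
                applied.append(col)
            else:
                rejected.append(col)
        return applied, rejected

    applied, rejected = partition_filter_columns(
        ["country", "made_up_col"], {"country", "amount"}, set()
    )
    assert applied == ["country"] and rejected == ["made_up_col"]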
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 11ea7ef528..e7d9f672e0 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -19,7 +19,7 @@ import inspect
import logging
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, Dict, Hashable, List, Optional, Type, TYPE_CHECKING
import simplejson as json
import sqlalchemy as sqla
@@ -52,9 +52,10 @@ from superset.models.helpers import (
)
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sqllab.limiting_factor import LimitingFactor
-from superset.utils.core import GenericDataType, QueryStatus, user_label
+from superset.utils.core import QueryStatus, user_label
if TYPE_CHECKING:
+ from superset.connectors.sqla.models import TableColumn
from superset.db_engine_specs import BaseEngineSpec
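Importing TableColumn only under TYPE_CHECKING (and again lazily inside the property further down) is the standard way to break the circular dependency between sql_lab.py and connectors/sqla/models.py. The pattern in isolation, using the module paths from this commit:

    from typing import TYPE_CHECKING, List

    if TYPE_CHECKING:
        # evaluated by type checkers only, so no import cycle at runtime
        from superset.connectors.sqla.models import TableColumn

    class Query:
        @property
        def columns(self) -> List["TableColumn"]:
            # deferred runtime import; by the time this runs, both modules
            # are fully initialized, so the cycle never bites
            from superset.connectors.sqla.models import TableColumn
            return [TableColumn(column_name="example")]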
@@ -183,47 +184,33 @@ class Query(
return list(ParsedQuery(self.sql).tables)
@property
- def columns(self) -> List[Dict[str, Any]]:
- bool_types = ("BOOL",)
- num_types = (
- "DOUBLE",
- "FLOAT",
- "INT",
- "BIGINT",
- "NUMBER",
- "LONG",
- "REAL",
- "NUMERIC",
- "DECIMAL",
- "MONEY",
+ def columns(self) -> List["TableColumn"]:
+ from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
+ TableColumn,
)
- date_types = ("DATE", "TIME")
- str_types = ("VARCHAR", "STRING", "CHAR")
+
columns = []
- col_type = ""
for col in self.extra.get("columns", []):
- computed_column = {**col}
- col_type = col.get("type")
-
- if col_type and any(map(lambda t: t in col_type.upper(), str_types)):
- computed_column["type_generic"] = GenericDataType.STRING
- if col_type and any(map(lambda t: t in col_type.upper(), bool_types)):
- computed_column["type_generic"] = GenericDataType.BOOLEAN
- if col_type and any(map(lambda t: t in col_type.upper(), num_types)):
- computed_column["type_generic"] = GenericDataType.NUMERIC
- if col_type and any(map(lambda t: t in col_type.upper(), date_types)):
- computed_column["type_generic"] = GenericDataType.TEMPORAL
-
- computed_column["column_name"] = col.get("name")
- computed_column["groupby"] = True
- columns.append(computed_column)
+ columns.append(
+ TableColumn(
+ column_name=col["name"],
+ type=col["type"],
+ is_dttm=col["is_dttm"],
+ groupby=True,
+ filterable=True,
+ )
+ )
return columns
+ @property
+ def db_extra(self) -> Optional[Dict[str, Any]]:
+ return None
+
@property
def data(self) -> Dict[str, Any]:
order_by_choices = []
for col in self.columns:
- column_name = str(col.get("column_name") or "")
+ column_name = str(col.column_name or "")
order_by_choices.append(
(json.dumps([column_name, True]), f"{column_name} " + __("[asc]"))
)
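The rewritten columns property drops the hand-rolled generic-type inference (string matching on BOOL/INT/DATE and so on) and returns real TableColumn objects, which already know how to map native types and serialize themselves. Call sites switch from col.get(...) to attribute access, and data serializes via each column's own payload, which is also why the frontend fixtures above moved from name to column_name. In short:

    # Attribute access and built-in serialization replace per-query dicts.
    from superset.connectors.sqla.models import TableColumn

    col = TableColumn(
        column_name="ds",
        type="TIMESTAMP",
        is_dttm=True,
        groupby=True,
        filterable=True,
    )
    col.column_name   # was: col.get("column_name")
    col.data          # dict payload the frontend consumes, incl. type_generic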
@@ -237,7 +224,7 @@ class Query(
],
"filter_select": True,
"name": self.tab_name,
- "columns": self.columns,
+ "columns": [o.data for o in self.columns],
"metrics": [],
"id": self.id,
"type": self.type,
@@ -279,7 +266,7 @@ class Query(
@property
def column_names(self) -> List[Any]:
- return [col.get("column_name") for col in self.columns]
+ return [col.column_name for col in self.columns]
@property
def offset(self) -> int:
@@ -294,7 +281,7 @@ class Query(
@property
def dttm_cols(self) -> List[Any]:
- return [col.get("column_name") for col in self.columns if col.get("is_dttm")]
+ return [col.column_name for col in self.columns if col.is_dttm]
@property
def schema_perm(self) -> str:
@@ -309,7 +296,7 @@ class Query(
return ""
@staticmethod
- def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
+ def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]:
return []
@property
@@ -337,7 +324,7 @@ class Query(
if not column_name:
return None
for col in self.columns:
- if col.get("column_name") == column_name:
+ if col.column_name == column_name:
return col
return None
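get_column keeps its linear scan, now comparing attributes instead of dict keys; an equivalent one-liner for reference:

    from typing import Iterable, Optional

    def get_column(columns: Iterable, column_name: Optional[str]):
        if not column_name:
            return None
        # first TableColumn whose name matches, else None
        return next((c for c in columns if c.column_name == column_name), None)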
diff --git a/superset/utils/core.py b/superset/utils/core.py
index d229942834..ee84d52eb9 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1785,15 +1785,9 @@ def get_time_filter_status(
datasource: "BaseDatasource",
applied_time_extras: Dict[str, str],
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
- temporal_columns: Set[Any]
- if datasource.type == "query":
- temporal_columns = {
- col.get("column_name") for col in datasource.columns if col.get("is_dttm")
- }
- else:
- temporal_columns = {
- col.column_name for col in datasource.columns if col.is_dttm
- }
+ temporal_columns: Set[Any] = {
+ col.column_name for col in datasource.columns if col.is_dttm
+ }
applied: List[Dict[str, str]] = []
rejected: List[Dict[str, str]] = []
time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL)
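With Query.columns now yielding TableColumn objects, the type branch in get_time_filter_status collapses: any datasource whose columns expose column_name and is_dttm works. A self-contained illustration of that duck typing:

    from dataclasses import dataclass
    from typing import Iterable, Set

    @dataclass
    class Col:  # stand-in for TableColumn
        column_name: str
        is_dttm: bool

    def temporal_columns(columns: Iterable[Col]) -> Set[str]:
        # one comprehension now covers both dataset and query datasources
        return {c.column_name for c in columns if c.is_dttm}

    assert temporal_columns([Col("ds", True), Col("x", False)]) == {"ds"}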
diff --git a/superset/views/core.py b/superset/views/core.py
index 44f1b78af0..7db64279a3 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -2021,7 +2021,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
db.session.add(table)
cols = []
for config_ in data.get("columns"):
- column_name = config_.get("name")
+ column_name = config_.get("column_name") or config_.get("name")
col = TableColumn(
column_name=column_name,
filterable=True,
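Reading column_name with a fallback to the legacy name key keeps dataset creation working for payloads produced by older SQL Lab clients while the frontend migrates:

    # Prefer the new key, fall back to the legacy one.
    config_ = {"name": "ds", "is_dttm": True}  # legacy-shaped payload
    column_name = config_.get("column_name") or config_.get("name")
    assert column_name == "ds"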
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index ed8de062d7..231f06598e 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -1186,8 +1186,8 @@ def test_chart_cache_timeout_chart_not_found(
[
(200, {"where": "1 = 1"}),
(200, {"having": "count(*) > 0"}),
- (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
- (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+ (403, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+ (403, {"having": "count(*) > (select count(*) from physical_dataset)"}),
],
)
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
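The expected status flips from 400 to 403: with ALLOW_ADHOC_SUBQUERY disabled, a subquery in WHERE or HAVING is now treated as a forbidden operation rather than a malformed request. A sketch of asserting that contract, where post_chart_data is a hypothetical fixture standing in for the test client call:

    import pytest

    @pytest.mark.parametrize(
        "status,extras",
        [
            (200, {"where": "1 = 1"}),
            (403, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
        ],
    )
    def test_subquery_rejected(status, extras, post_chart_data):
        # post_chart_data: hypothetical fixture issuing the chart-data request
        assert post_chart_data(extras).status_code == status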
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index d9f26239d1..27ccdde96b 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -493,8 +493,16 @@ class TestSqlLab(SupersetTestCase):
"datasourceName": f"test_viz_flow_table_{random()}",
"schema": "superset",
"columns": [
- {"is_dttm": False, "type": "STRING", "name": f"viz_type_{random()}"},
- {"is_dttm": False, "type": "OBJECT", "name": f"ccount_{random()}"},
+ {
+ "is_dttm": False,
+ "type": "STRING",
+ "column_name": f"viz_type_{random()}",
+ },
+ {
+ "is_dttm": False,
+ "type": "OBJECT",
+ "column_name": f"ccount_{random()}",
+ },
],
"sql": """\
SELECT *
@@ -523,8 +531,16 @@ class TestSqlLab(SupersetTestCase):
"chartType": "dist_bar",
"schema": "superset",
"columns": [
- {"is_dttm": False, "type": "STRING", "name": f"viz_type_{random()}"},
- {"is_dttm": False, "type": "OBJECT", "name": f"ccount_{random()}"},
+ {
+ "is_dttm": False,
+ "type": "STRING",
+ "column_name": f"viz_type_{random()}",
+ },
+ {
+ "is_dttm": False,
+ "type": "OBJECT",
+ "column_name": f"ccount_{random()}",
+ },
],
"sql": """\
SELECT *