You are viewing a plain-text version of this content. The canonical link to the original message was not preserved in this copy.
Posted to commits@superset.apache.org by vi...@apache.org on 2022/12/03 04:19:32 UTC
[superset] branch master updated: fix(sqla): use same template processor in all methods (#22280)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 1ad5147016 fix(sqla): use same template processor in all methods (#22280)
1ad5147016 is described below
commit 1ad514701609785f19b27ad495ba34f3b9fff585
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Sat Dec 3 06:19:25 2022 +0200
fix(sqla): use same template processor in all methods (#22280)
---
superset/connectors/sqla/models.py | 152 +++++++++++++++++++--------
tests/integration_tests/sqla_models_tests.py | 43 ++++++--
2 files changed, 139 insertions(+), 56 deletions(-)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index a67e686ff3..fd5942c51a 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
+from __future__ import annotations
+
import dataclasses
import json
import logging
@@ -222,7 +224,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
- table: "SqlaTable" = relationship(
+ table: SqlaTable = relationship(
"SqlaTable",
backref=backref("columns", cascade="all, delete-orphan"),
foreign_keys=[table_id],
@@ -301,14 +303,18 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
)
return column_spec.generic_type if column_spec else None
- def get_sqla_col(self, label: Optional[str] = None) -> Column:
+ def get_sqla_col(
+ self,
+ label: Optional[str] = None,
+ template_processor: Optional[BaseTemplateProcessor] = None,
+ ) -> Column:
label = label or self.column_name
db_engine_spec = self.db_engine_spec
column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra)
type_ = column_spec.sqla_type if column_spec else None
- if self.expression:
- tp = self.table.get_template_processor()
- expression = tp.process_template(self.expression)
+ if expression := self.expression:
+ if template_processor:
+ expression = template_processor.process_template(expression)
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
@@ -324,8 +330,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
start_dttm: Optional[DateTime] = None,
end_dttm: Optional[DateTime] = None,
label: Optional[str] = "__time",
+ template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
- col = self.get_sqla_col(label=label)
+ col = self.get_sqla_col(label=label, template_processor=template_processor)
l = []
if start_dttm:
l.append(col >= self.table.text(self.dttm_sql_literal(start_dttm)))
@@ -358,10 +365,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
if not self.expression and not time_grain and not is_epoch:
sqla_col = column(self.column_name, type_=type_)
return self.table.make_sqla_column_compatible(sqla_col, label)
- if self.expression:
- expression = self.expression
+ if expression := self.expression:
if template_processor:
- expression = template_processor.process_template(self.expression)
+ expression = template_processor.process_template(expression)
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
@@ -458,10 +464,17 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
def __repr__(self) -> str:
return str(self.metric_name)
- def get_sqla_col(self, label: Optional[str] = None) -> Column:
+ def get_sqla_col(
+ self,
+ label: Optional[str] = None,
+ template_processor: Optional[BaseTemplateProcessor] = None,
+ ) -> Column:
label = label or self.metric_name
- tp = self.table.get_template_processor()
- sqla_col: ColumnClause = literal_column(tp.process_template(self.expression))
+ expression = self.expression
+ if template_processor:
+ expression = template_processor.process_template(expression)
+
+ sqla_col: ColumnClause = literal_column(expression)
return self.table.make_sqla_column_compatible(sqla_col, label)
@property
@@ -650,7 +663,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
datasource_name: str,
schema: Optional[str],
database_name: str,
- ) -> Optional["SqlaTable"]:
+ ) -> Optional[SqlaTable]:
schema = schema or None
query = (
session.query(cls)
@@ -778,10 +791,17 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
except (TypeError, json.JSONDecodeError):
return {}
- def get_fetch_values_predicate(self) -> TextClause:
- tp = self.get_template_processor()
+ def get_fetch_values_predicate(
+ self,
+ template_processor: Optional[BaseTemplateProcessor] = None,
+ ) -> TextClause:
+ fetch_values_predicate = self.fetch_values_predicate
+ if template_processor:
+ fetch_values_predicate = template_processor.process_template(
+ fetch_values_predicate
+ )
try:
- return self.text(tp.process_template(self.fetch_values_predicate))
+ return self.text(fetch_values_predicate)
except TemplateError as ex:
raise QueryObjectValidationError(
_(
@@ -799,12 +819,16 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
tp = self.get_template_processor()
tbl, cte = self.get_from_clause(tp)
- qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct()
+ qry = (
+ select([target_col.get_sqla_col(template_processor=tp)])
+ .select_from(tbl)
+ .distinct()
+ )
if limit:
qry = qry.limit(limit)
if self.fetch_values_predicate:
- qry = qry.where(self.get_fetch_values_predicate())
+ qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
with self.database.get_sqla_engine_with_context() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
@@ -936,7 +960,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
column_name = cast(str, metric_column.get("column_name"))
table_column: Optional[TableColumn] = columns_by_name.get(column_name)
if table_column:
- sqla_column = table_column.get_sqla_col()
+ sqla_column = table_column.get_sqla_col(
+ template_processor=template_processor
+ )
else:
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
@@ -975,7 +1001,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
col_in_metadata = self.get_column(expression)
if col_in_metadata:
- sqla_column = col_in_metadata.get_sqla_col()
+ sqla_column = col_in_metadata.get_sqla_col(
+ template_processor=template_processor
+ )
is_dttm = col_in_metadata.is_temporal
else:
sqla_column = literal_column(expression)
@@ -1190,7 +1218,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
)
elif isinstance(metric, str) and metric in metrics_by_name:
- metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
+ metrics_exprs.append(
+ metrics_by_name[metric].get_sqla_col(
+ template_processor=template_processor
+ )
+ )
else:
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=metric)
@@ -1229,12 +1261,16 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
col = metrics_exprs_by_expr.get(str(col), col)
need_groupby = True
elif col in columns_by_name:
- col = columns_by_name[col].get_sqla_col()
+ col = columns_by_name[col].get_sqla_col(
+ template_processor=template_processor
+ )
elif col in metrics_exprs_by_label:
col = metrics_exprs_by_label[col]
need_groupby = True
elif col in metrics_by_name:
- col = metrics_by_name[col].get_sqla_col()
+ col = metrics_by_name[col].get_sqla_col(
+ template_processor=template_processor
+ )
need_groupby = True
if isinstance(col, ColumnElement):
@@ -1268,7 +1304,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
# if groupby field equals a selected column
elif selected in columns_by_name:
- outer = columns_by_name[selected].get_sqla_col()
+ outer = columns_by_name[selected].get_sqla_col(
+ template_processor=template_processor
+ )
else:
selected = validate_adhoc_subquery(
selected,
@@ -1302,7 +1340,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
self.schema,
)
select_exprs.append(
- columns_by_name[selected].get_sqla_col()
+ columns_by_name[selected].get_sqla_col(
+ template_processor=template_processor
+ )
if isinstance(selected, str) and selected in columns_by_name
else self.make_sqla_column_compatible(
literal_column(selected), _column_label
@@ -1336,11 +1376,18 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
):
time_filters.append(
columns_by_name[self.main_dttm_col].get_time_filter(
- from_dttm,
- to_dttm,
+ start_dttm=from_dttm,
+ end_dttm=to_dttm,
+ template_processor=template_processor,
)
)
- time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
+ time_filters.append(
+ dttm_col.get_time_filter(
+ start_dttm=from_dttm,
+ end_dttm=to_dttm,
+ template_processor=template_processor,
+ )
+ )
# Always remove duplicates by column name, as sometimes `metrics_exprs`
# can have the same name as a groupby column (e.g. when users use
@@ -1396,7 +1443,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
time_grain=filter_grain, template_processor=template_processor
)
elif col_obj:
- sqla_col = col_obj.get_sqla_col()
+ sqla_col = col_obj.get_sqla_col(
+ template_processor=template_processor
+ )
col_type = col_obj.type if col_obj else None
col_spec = db_engine_spec.get_column_spec(
native_type=col_type,
@@ -1521,6 +1570,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
start_dttm=_since,
end_dttm=_until,
label=sqla_col.key,
+ template_processor=template_processor,
)
)
else:
@@ -1565,7 +1615,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
having_clause_and += [self.text(having)]
if apply_fetch_values_predicate and self.fetch_values_predicate:
- qry = qry.where(self.get_fetch_values_predicate())
+ qry = qry.where(
+ self.get_fetch_values_predicate(template_processor=template_processor)
+ )
if granularity:
qry = qry.where(and_(*(time_filters + where_clause_and)))
else:
@@ -1617,8 +1669,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if dttm_col and not db_engine_spec.time_groupby_inline:
inner_time_filter = [
dttm_col.get_time_filter(
- inner_from_dttm or from_dttm,
- inner_to_dttm or to_dttm,
+ start_dttm=inner_from_dttm or from_dttm,
+ end_dttm=inner_to_dttm or to_dttm,
+ template_processor=template_processor,
)
]
subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
@@ -1627,7 +1680,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
ob = inner_main_metric_expr
if series_limit_metric:
ob = self._get_series_orderby(
- series_limit_metric, metrics_by_name, columns_by_name
+ series_limit_metric=series_limit_metric,
+ metrics_by_name=metrics_by_name,
+ columns_by_name=columns_by_name,
+ template_processor=template_processor,
)
direction = desc if order_desc else asc
subq = subq.order_by(direction(ob))
@@ -1647,9 +1703,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
orderby = [
(
self._get_series_orderby(
- series_limit_metric,
- metrics_by_name,
- columns_by_name,
+ series_limit_metric=series_limit_metric,
+ metrics_by_name=metrics_by_name,
+ columns_by_name=columns_by_name,
+ template_processor=template_processor,
),
not order_desc,
)
@@ -1709,6 +1766,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
series_limit_metric: Metric,
metrics_by_name: Dict[str, SqlMetric],
columns_by_name: Dict[str, TableColumn],
+ template_processor: Optional[BaseTemplateProcessor] = None,
) -> Column:
if utils.is_adhoc_metric(series_limit_metric):
assert isinstance(series_limit_metric, dict)
@@ -1717,7 +1775,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
isinstance(series_limit_metric, str)
and series_limit_metric in metrics_by_name
):
- ob = metrics_by_name[series_limit_metric].get_sqla_col()
+ ob = metrics_by_name[series_limit_metric].get_sqla_col(
+ template_processor=template_processor
+ )
else:
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=series_limit_metric)
@@ -1930,7 +1990,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
database: Database,
datasource_name: str,
schema: Optional[str] = None,
- ) -> List["SqlaTable"]:
+ ) -> List[SqlaTable]:
query = (
session.query(cls)
.filter_by(database_id=database.id)
@@ -1947,7 +2007,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
database: Database,
permissions: Set[str],
schema_perms: Set[str],
- ) -> List["SqlaTable"]:
+ ) -> List[SqlaTable]:
# TODO(hughhhh): add unit test
return (
session.query(cls)
@@ -1964,7 +2024,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
@classmethod
def get_eager_sqlatable_datasource(
cls, session: Session, datasource_id: int
- ) -> "SqlaTable":
+ ) -> SqlaTable:
"""Returns SqlaTable with columns and metrics."""
return (
session.query(cls)
@@ -1977,7 +2037,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
@classmethod
- def get_all_datasources(cls, session: Session) -> List["SqlaTable"]:
+ def get_all_datasources(cls, session: Session) -> List[SqlaTable]:
qry = session.query(cls)
qry = cls.default_query(qry)
return qry.all()
@@ -2038,7 +2098,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def before_update(
mapper: Mapper, # pylint: disable=unused-argument
connection: Connection, # pylint: disable=unused-argument
- target: "SqlaTable",
+ target: SqlaTable,
) -> None:
"""
Check before update if the target table already exists.
@@ -2110,7 +2170,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def after_insert(
mapper: Mapper,
connection: Connection,
- sqla_table: "SqlaTable",
+ sqla_table: SqlaTable,
) -> None:
"""
Update dataset permissions after insert
@@ -2124,7 +2184,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def after_delete(
mapper: Mapper,
connection: Connection,
- sqla_table: "SqlaTable",
+ sqla_table: SqlaTable,
) -> None:
"""
Update dataset permissions after delete
@@ -2135,7 +2195,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def after_update(
mapper: Mapper,
connection: Connection,
- sqla_table: "SqlaTable",
+ sqla_table: SqlaTable,
) -> None:
"""
Update dataset permissions
@@ -2170,7 +2230,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
return
def write_shadow_dataset(
- self: "SqlaTable",
+ self: SqlaTable,
) -> None:
"""
This method is deprecated
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 3088bdfb02..dfba16179d 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -201,22 +201,34 @@ class TestDatabaseModel(SupersetTestCase):
"granularity": None,
"from_dttm": None,
"to_dttm": None,
- "groupby": ["user", "expr"],
+ "columns": [
+ "user",
+ "expr",
+ {
+ "hasCustomLabel": True,
+ "label": "adhoc_column",
+ "sqlExpression": "'{{ 'foo_' + time_grain }}'",
+ },
+ ],
"metrics": [
{
+ "hasCustomLabel": True,
+ "label": "adhoc_metric",
"expressionType": AdhocMetricExpressionType.SQL,
- "sqlExpression": "SUM(case when user = '{{ current_username() }}' "
- "then 1 else 0 end)",
- "label": "SUM(userid)",
- }
+ "sqlExpression": "SUM(case when user = '{{ 'user_' + "
+ "current_username() }}' then 1 else 0 end)",
+ },
+ "count_timegrain",
],
"is_timeseries": False,
"filter": [],
+ "extras": {"time_grain_sqla": "P1D"},
}
table = SqlaTable(
table_name="test_has_jinja_metric_and_expr",
- sql="SELECT '{{ current_username() }}' as user",
+ sql="SELECT '{{ 'user_' + current_username() }}' as user, "
+ "'{{ 'xyz_' + time_grain }}' as time_grain",
database=get_example_database(),
)
TableColumn(
@@ -226,14 +238,25 @@ class TestDatabaseModel(SupersetTestCase):
type="VARCHAR(100)",
table=table,
)
+ SqlMetric(
+ metric_name="count_timegrain",
+ expression="count('{{ 'bar_' + time_grain }}')",
+ table=table,
+ )
db.session.commit()
sqla_query = table.get_sqla_query(**base_query_obj)
query = table.database.compile_sqla_query(sqla_query.sqla_query)
- # assert expression
- assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
- # assert metric
- assert "SUM(case when user = 'abc' then 1 else 0 end)" in query
+ # assert virtual dataset
+ assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query
+ # assert dataset calculated column
+ assert "case when 'abc' = 'abc' then 'yes' else 'no' end AS expr" in query
+ # assert adhoc column
+ assert "'foo_P1D'" in query
+ # assert dataset saved metric
+ assert "count('bar_P1D')" in query
+ # assert adhoc metric
+ assert "SUM(case when user = 'user_abc' then 1 else 0 end)" in query
# Cleanup
db.session.delete(table)
db.session.commit()