You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2022/04/08 13:48:02 UTC
[superset] 05/09: fix(sqla): apply jinja to metrics (#19565)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch 1.5
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 44eb81e35f769b2f3f3224a3c8e6b2045e48e5cc
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Thu Apr 7 14:04:51 2022 +0300
fix(sqla): apply jinja to metrics (#19565)
(cherry picked from commit 34b55765c4b0cbd8f0b9f89c6ca0f62f4478270e)
---
superset/connectors/sqla/models.py | 83 ++++++++++++++++++++++++--------------
1 file changed, 52 insertions(+), 31 deletions(-)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 8721f6ea81..b8d3a7d091 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -354,6 +354,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
:param time_grain: Optional time grain, e.g. P1Y
:param label: alias/label that column is expected to have
+ :param template_processor: template processor
:return: A TimeExpression object wrapped in a Label if supported by db
"""
label = label or utils.DTTM_ALIAS
@@ -517,6 +518,27 @@ sqlatable_user = Table(
)
+def _process_sql_expression(
+    expression: Optional[str],  # None/"" is returned unchanged
+    database_id: int,
+    schema: str,
+    template_processor: Optional[BaseTemplateProcessor],  # None: skip templating
+) -> Optional[str]:
+    if template_processor and expression:  # step 1: apply the template processor
+        expression = template_processor.process_template(expression)
+    if expression:  # steps 2-3 run on the post-template SQL
+        expression = validate_adhoc_subquery(  # step 2
+            expression,
+            database_id,
+            schema,
+        )
+        try:
+            expression = sanitize_clause(expression)  # step 3
+        except QueryClauseValidationException as ex:
+            raise QueryObjectValidationError(ex.message) from ex
+    return expression
+
+
class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods
"""An ORM object for SqlAlchemy table references"""
@@ -899,13 +921,17 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
return sql
def adhoc_metric_to_sqla(
- self, metric: AdhocMetric, columns_by_name: Dict[str, TableColumn]
+ self,
+ metric: AdhocMetric,
+ columns_by_name: Dict[str, TableColumn],
+ template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
"""
Turn an adhoc metric into a sqlalchemy column.
:param dict metric: Adhoc metric definition
:param dict columns_by_name: Columns for the current table
+ :param template_processor: template_processor instance
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
@@ -922,17 +948,12 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
elif expression_type == utils.AdhocMetricExpressionType.SQL:
- tp = self.get_template_processor()
- expression = tp.process_template(cast(str, metric["sqlExpression"]))
- expression = validate_adhoc_subquery(
- expression,
- self.database_id,
- self.schema,
+ expression = _process_sql_expression(
+ expression=metric["sqlExpression"],
+ database_id=self.database_id,
+ schema=self.schema,
+ template_processor=template_processor,
)
- try:
- expression = sanitize_clause(expression)
- except QueryClauseValidationException as ex:
- raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
else:
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
@@ -953,21 +974,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
:rtype: sqlalchemy.sql.column
"""
label = utils.get_column_name(col)
- expression = col["sqlExpression"]
- if template_processor and expression:
- expression = template_processor.process_template(expression)
- if expression:
- expression = validate_adhoc_subquery(
- expression,
- self.database_id,
- self.schema,
- )
- try:
- expression = sanitize_clause(expression)
- except QueryClauseValidationException as ex:
- raise QueryObjectValidationError(ex.message) from ex
- sqla_metric = literal_column(expression)
- return self.make_sqla_column_compatible(sqla_metric, label)
+ expression = _process_sql_expression(
+ expression=col["sqlExpression"],
+ database_id=self.database_id,
+ schema=self.schema,
+ template_processor=template_processor,
+ )
+ sqla_column = literal_column(expression)
+ return self.make_sqla_column_compatible(sqla_column, label)
def make_sqla_column_compatible(
self, sqla_col: ColumnElement, label: Optional[str] = None
@@ -1151,7 +1165,13 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
for metric in metrics:
if utils.is_adhoc_metric(metric):
assert isinstance(metric, dict)
- metrics_exprs.append(self.adhoc_metric_to_sqla(metric, columns_by_name))
+ metrics_exprs.append(
+ self.adhoc_metric_to_sqla(
+ metric=metric,
+ columns_by_name=columns_by_name,
+ template_processor=template_processor,
+ )
+ )
elif isinstance(metric, str) and metric in metrics_by_name:
metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
else:
@@ -1178,10 +1198,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if isinstance(col, dict):
col = cast(AdhocMetric, col)
if col.get("sqlExpression"):
- col["sqlExpression"] = validate_adhoc_subquery(
- cast(str, col["sqlExpression"]),
- self.database_id,
- self.schema,
+ col["sqlExpression"] = _process_sql_expression(
+ expression=col["sqlExpression"],
+ database_id=self.database_id,
+ schema=self.schema,
+ template_processor=template_processor,
)
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists