You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@superset.apache.org by vi...@apache.org on 2022/04/08 13:48:02 UTC

[superset] 05/09: fix(sqla): apply jinja to metrics (#19565)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 1.5
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 44eb81e35f769b2f3f3224a3c8e6b2045e48e5cc
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Thu Apr 7 14:04:51 2022 +0300

    fix(sqla): apply jinja to metrics (#19565)
    
    (cherry picked from commit 34b55765c4b0cbd8f0b9f89c6ca0f62f4478270e)
---
 superset/connectors/sqla/models.py | 83 ++++++++++++++++++++++++--------------
 1 file changed, 52 insertions(+), 31 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 8721f6ea81..b8d3a7d091 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -354,6 +354,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
 
         :param time_grain: Optional time grain, e.g. P1Y
         :param label: alias/label that column is expected to have
+        :param template_processor: template processor
         :return: A TimeExpression object wrapped in a Label if supported by db
         """
         label = label or utils.DTTM_ALIAS
@@ -517,6 +518,27 @@ sqlatable_user = Table(
 )
 
 
+def _process_sql_expression(
+    expression: Optional[str],
+    database_id: int,
+    schema: str,
+    template_processor: Optional[BaseTemplateProcessor],
+) -> Optional[str]:
+    if template_processor and expression:
+        expression = template_processor.process_template(expression)
+    if expression:
+        expression = validate_adhoc_subquery(
+            expression,
+            database_id,
+            schema,
+        )
+        try:
+            expression = sanitize_clause(expression)
+        except QueryClauseValidationException as ex:
+            raise QueryObjectValidationError(ex.message) from ex
+    return expression
+
+
 class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
     """An ORM object for SqlAlchemy table references"""
 
@@ -899,13 +921,17 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         return sql
 
     def adhoc_metric_to_sqla(
-        self, metric: AdhocMetric, columns_by_name: Dict[str, TableColumn]
+        self,
+        metric: AdhocMetric,
+        columns_by_name: Dict[str, TableColumn],
+        template_processor: Optional[BaseTemplateProcessor] = None,
     ) -> ColumnElement:
         """
         Turn an adhoc metric into a sqlalchemy column.
 
         :param dict metric: Adhoc metric definition
         :param dict columns_by_name: Columns for the current table
+        :param template_processor: template_processor instance
         :returns: The metric defined as a sqlalchemy column
         :rtype: sqlalchemy.sql.column
         """
@@ -922,17 +948,12 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 sqla_column = column(column_name)
             sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
         elif expression_type == utils.AdhocMetricExpressionType.SQL:
-            tp = self.get_template_processor()
-            expression = tp.process_template(cast(str, metric["sqlExpression"]))
-            expression = validate_adhoc_subquery(
-                expression,
-                self.database_id,
-                self.schema,
+            expression = _process_sql_expression(
+                expression=metric["sqlExpression"],
+                database_id=self.database_id,
+                schema=self.schema,
+                template_processor=template_processor,
             )
-            try:
-                expression = sanitize_clause(expression)
-            except QueryClauseValidationException as ex:
-                raise QueryObjectValidationError(ex.message) from ex
             sqla_metric = literal_column(expression)
         else:
             raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
@@ -953,21 +974,14 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         :rtype: sqlalchemy.sql.column
         """
         label = utils.get_column_name(col)
-        expression = col["sqlExpression"]
-        if template_processor and expression:
-            expression = template_processor.process_template(expression)
-        if expression:
-            expression = validate_adhoc_subquery(
-                expression,
-                self.database_id,
-                self.schema,
-            )
-            try:
-                expression = sanitize_clause(expression)
-            except QueryClauseValidationException as ex:
-                raise QueryObjectValidationError(ex.message) from ex
-        sqla_metric = literal_column(expression)
-        return self.make_sqla_column_compatible(sqla_metric, label)
+        expression = _process_sql_expression(
+            expression=col["sqlExpression"],
+            database_id=self.database_id,
+            schema=self.schema,
+            template_processor=template_processor,
+        )
+        sqla_column = literal_column(expression)
+        return self.make_sqla_column_compatible(sqla_column, label)
 
     def make_sqla_column_compatible(
         self, sqla_col: ColumnElement, label: Optional[str] = None
@@ -1151,7 +1165,13 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         for metric in metrics:
             if utils.is_adhoc_metric(metric):
                 assert isinstance(metric, dict)
-                metrics_exprs.append(self.adhoc_metric_to_sqla(metric, columns_by_name))
+                metrics_exprs.append(
+                    self.adhoc_metric_to_sqla(
+                        metric=metric,
+                        columns_by_name=columns_by_name,
+                        template_processor=template_processor,
+                    )
+                )
             elif isinstance(metric, str) and metric in metrics_by_name:
                 metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
             else:
@@ -1178,10 +1198,11 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
             if isinstance(col, dict):
                 col = cast(AdhocMetric, col)
                 if col.get("sqlExpression"):
-                    col["sqlExpression"] = validate_adhoc_subquery(
-                        cast(str, col["sqlExpression"]),
-                        self.database_id,
-                        self.schema,
+                    col["sqlExpression"] = _process_sql_expression(
+                        expression=col["sqlExpression"],
+                        database_id=self.database_id,
+                        schema=self.schema,
+                        template_processor=template_processor,
                     )
                 if utils.is_adhoc_metric(col):
                     # add adhoc sort by column to columns_by_name if not exists