You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2022/12/03 04:19:32 UTC

[superset] branch master updated: fix(sqla): use same template processor in all methods (#22280)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ad5147016 fix(sqla): use same template processor in all methods (#22280)
1ad5147016 is described below

commit 1ad514701609785f19b27ad495ba34f3b9fff585
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Sat Dec 3 06:19:25 2022 +0200

    fix(sqla): use same template processor in all methods (#22280)
---
 superset/connectors/sqla/models.py           | 152 +++++++++++++++++++--------
 tests/integration_tests/sqla_models_tests.py |  43 ++++++--
 2 files changed, 139 insertions(+), 56 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index a67e686ff3..fd5942c51a 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=too-many-lines
+from __future__ import annotations
+
 import dataclasses
 import json
 import logging
@@ -222,7 +224,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
     __tablename__ = "table_columns"
     __table_args__ = (UniqueConstraint("table_id", "column_name"),)
     table_id = Column(Integer, ForeignKey("tables.id"))
-    table: "SqlaTable" = relationship(
+    table: SqlaTable = relationship(
         "SqlaTable",
         backref=backref("columns", cascade="all, delete-orphan"),
         foreign_keys=[table_id],
@@ -301,14 +303,18 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
         )
         return column_spec.generic_type if column_spec else None
 
-    def get_sqla_col(self, label: Optional[str] = None) -> Column:
+    def get_sqla_col(
+        self,
+        label: Optional[str] = None,
+        template_processor: Optional[BaseTemplateProcessor] = None,
+    ) -> Column:
         label = label or self.column_name
         db_engine_spec = self.db_engine_spec
         column_spec = db_engine_spec.get_column_spec(self.type, db_extra=self.db_extra)
         type_ = column_spec.sqla_type if column_spec else None
-        if self.expression:
-            tp = self.table.get_template_processor()
-            expression = tp.process_template(self.expression)
+        if expression := self.expression:
+            if template_processor:
+                expression = template_processor.process_template(expression)
             col = literal_column(expression, type_=type_)
         else:
             col = column(self.column_name, type_=type_)
@@ -324,8 +330,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
         start_dttm: Optional[DateTime] = None,
         end_dttm: Optional[DateTime] = None,
         label: Optional[str] = "__time",
+        template_processor: Optional[BaseTemplateProcessor] = None,
     ) -> ColumnElement:
-        col = self.get_sqla_col(label=label)
+        col = self.get_sqla_col(label=label, template_processor=template_processor)
         l = []
         if start_dttm:
             l.append(col >= self.table.text(self.dttm_sql_literal(start_dttm)))
@@ -358,10 +365,9 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
         if not self.expression and not time_grain and not is_epoch:
             sqla_col = column(self.column_name, type_=type_)
             return self.table.make_sqla_column_compatible(sqla_col, label)
-        if self.expression:
-            expression = self.expression
+        if expression := self.expression:
             if template_processor:
-                expression = template_processor.process_template(self.expression)
+                expression = template_processor.process_template(expression)
             col = literal_column(expression, type_=type_)
         else:
             col = column(self.column_name, type_=type_)
@@ -458,10 +464,17 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
     def __repr__(self) -> str:
         return str(self.metric_name)
 
-    def get_sqla_col(self, label: Optional[str] = None) -> Column:
+    def get_sqla_col(
+        self,
+        label: Optional[str] = None,
+        template_processor: Optional[BaseTemplateProcessor] = None,
+    ) -> Column:
         label = label or self.metric_name
-        tp = self.table.get_template_processor()
-        sqla_col: ColumnClause = literal_column(tp.process_template(self.expression))
+        expression = self.expression
+        if template_processor:
+            expression = template_processor.process_template(expression)
+
+        sqla_col: ColumnClause = literal_column(expression)
         return self.table.make_sqla_column_compatible(sqla_col, label)
 
     @property
@@ -650,7 +663,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         datasource_name: str,
         schema: Optional[str],
         database_name: str,
-    ) -> Optional["SqlaTable"]:
+    ) -> Optional[SqlaTable]:
         schema = schema or None
         query = (
             session.query(cls)
@@ -778,10 +791,17 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         except (TypeError, json.JSONDecodeError):
             return {}
 
-    def get_fetch_values_predicate(self) -> TextClause:
-        tp = self.get_template_processor()
+    def get_fetch_values_predicate(
+        self,
+        template_processor: Optional[BaseTemplateProcessor] = None,
+    ) -> TextClause:
+        fetch_values_predicate = self.fetch_values_predicate
+        if template_processor:
+            fetch_values_predicate = template_processor.process_template(
+                fetch_values_predicate
+            )
         try:
-            return self.text(tp.process_template(self.fetch_values_predicate))
+            return self.text(fetch_values_predicate)
         except TemplateError as ex:
             raise QueryObjectValidationError(
                 _(
@@ -799,12 +819,16 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         tp = self.get_template_processor()
         tbl, cte = self.get_from_clause(tp)
 
-        qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct()
+        qry = (
+            select([target_col.get_sqla_col(template_processor=tp)])
+            .select_from(tbl)
+            .distinct()
+        )
         if limit:
             qry = qry.limit(limit)
 
         if self.fetch_values_predicate:
-            qry = qry.where(self.get_fetch_values_predicate())
+            qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
 
         with self.database.get_sqla_engine_with_context() as engine:
             sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
@@ -936,7 +960,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
             column_name = cast(str, metric_column.get("column_name"))
             table_column: Optional[TableColumn] = columns_by_name.get(column_name)
             if table_column:
-                sqla_column = table_column.get_sqla_col()
+                sqla_column = table_column.get_sqla_col(
+                    template_processor=template_processor
+                )
             else:
                 sqla_column = column(column_name)
             sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
@@ -975,7 +1001,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         )
         col_in_metadata = self.get_column(expression)
         if col_in_metadata:
-            sqla_column = col_in_metadata.get_sqla_col()
+            sqla_column = col_in_metadata.get_sqla_col(
+                template_processor=template_processor
+            )
             is_dttm = col_in_metadata.is_temporal
         else:
             sqla_column = literal_column(expression)
@@ -1190,7 +1218,11 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                     )
                 )
             elif isinstance(metric, str) and metric in metrics_by_name:
-                metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
+                metrics_exprs.append(
+                    metrics_by_name[metric].get_sqla_col(
+                        template_processor=template_processor
+                    )
+                )
             else:
                 raise QueryObjectValidationError(
                     _("Metric '%(metric)s' does not exist", metric=metric)
@@ -1229,12 +1261,16 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                     col = metrics_exprs_by_expr.get(str(col), col)
                     need_groupby = True
             elif col in columns_by_name:
-                col = columns_by_name[col].get_sqla_col()
+                col = columns_by_name[col].get_sqla_col(
+                    template_processor=template_processor
+                )
             elif col in metrics_exprs_by_label:
                 col = metrics_exprs_by_label[col]
                 need_groupby = True
             elif col in metrics_by_name:
-                col = metrics_by_name[col].get_sqla_col()
+                col = metrics_by_name[col].get_sqla_col(
+                    template_processor=template_processor
+                )
                 need_groupby = True
 
             if isinstance(col, ColumnElement):
@@ -1268,7 +1304,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                         )
                     # if groupby field equals a selected column
                     elif selected in columns_by_name:
-                        outer = columns_by_name[selected].get_sqla_col()
+                        outer = columns_by_name[selected].get_sqla_col(
+                            template_processor=template_processor
+                        )
                     else:
                         selected = validate_adhoc_subquery(
                             selected,
@@ -1302,7 +1340,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                     self.schema,
                 )
                 select_exprs.append(
-                    columns_by_name[selected].get_sqla_col()
+                    columns_by_name[selected].get_sqla_col(
+                        template_processor=template_processor
+                    )
                     if isinstance(selected, str) and selected in columns_by_name
                     else self.make_sqla_column_compatible(
                         literal_column(selected), _column_label
@@ -1336,11 +1376,18 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
             ):
                 time_filters.append(
                     columns_by_name[self.main_dttm_col].get_time_filter(
-                        from_dttm,
-                        to_dttm,
+                        start_dttm=from_dttm,
+                        end_dttm=to_dttm,
+                        template_processor=template_processor,
                     )
                 )
-            time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
+            time_filters.append(
+                dttm_col.get_time_filter(
+                    start_dttm=from_dttm,
+                    end_dttm=to_dttm,
+                    template_processor=template_processor,
+                )
+            )
 
         # Always remove duplicates by column name, as sometimes `metrics_exprs`
         # can have the same name as a groupby column (e.g. when users use
@@ -1396,7 +1443,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                         time_grain=filter_grain, template_processor=template_processor
                     )
                 elif col_obj:
-                    sqla_col = col_obj.get_sqla_col()
+                    sqla_col = col_obj.get_sqla_col(
+                        template_processor=template_processor
+                    )
                 col_type = col_obj.type if col_obj else None
                 col_spec = db_engine_spec.get_column_spec(
                     native_type=col_type,
@@ -1521,6 +1570,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                                 start_dttm=_since,
                                 end_dttm=_until,
                                 label=sqla_col.key,
+                                template_processor=template_processor,
                             )
                         )
                     else:
@@ -1565,7 +1615,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 having_clause_and += [self.text(having)]
 
         if apply_fetch_values_predicate and self.fetch_values_predicate:
-            qry = qry.where(self.get_fetch_values_predicate())
+            qry = qry.where(
+                self.get_fetch_values_predicate(template_processor=template_processor)
+            )
         if granularity:
             qry = qry.where(and_(*(time_filters + where_clause_and)))
         else:
@@ -1617,8 +1669,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 if dttm_col and not db_engine_spec.time_groupby_inline:
                     inner_time_filter = [
                         dttm_col.get_time_filter(
-                            inner_from_dttm or from_dttm,
-                            inner_to_dttm or to_dttm,
+                            start_dttm=inner_from_dttm or from_dttm,
+                            end_dttm=inner_to_dttm or to_dttm,
+                            template_processor=template_processor,
                         )
                     ]
                 subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
@@ -1627,7 +1680,10 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 ob = inner_main_metric_expr
                 if series_limit_metric:
                     ob = self._get_series_orderby(
-                        series_limit_metric, metrics_by_name, columns_by_name
+                        series_limit_metric=series_limit_metric,
+                        metrics_by_name=metrics_by_name,
+                        columns_by_name=columns_by_name,
+                        template_processor=template_processor,
                     )
                 direction = desc if order_desc else asc
                 subq = subq.order_by(direction(ob))
@@ -1647,9 +1703,10 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                     orderby = [
                         (
                             self._get_series_orderby(
-                                series_limit_metric,
-                                metrics_by_name,
-                                columns_by_name,
+                                series_limit_metric=series_limit_metric,
+                                metrics_by_name=metrics_by_name,
+                                columns_by_name=columns_by_name,
+                                template_processor=template_processor,
                             ),
                             not order_desc,
                         )
@@ -1709,6 +1766,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         series_limit_metric: Metric,
         metrics_by_name: Dict[str, SqlMetric],
         columns_by_name: Dict[str, TableColumn],
+        template_processor: Optional[BaseTemplateProcessor] = None,
     ) -> Column:
         if utils.is_adhoc_metric(series_limit_metric):
             assert isinstance(series_limit_metric, dict)
@@ -1717,7 +1775,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
             isinstance(series_limit_metric, str)
             and series_limit_metric in metrics_by_name
         ):
-            ob = metrics_by_name[series_limit_metric].get_sqla_col()
+            ob = metrics_by_name[series_limit_metric].get_sqla_col(
+                template_processor=template_processor
+            )
         else:
             raise QueryObjectValidationError(
                 _("Metric '%(metric)s' does not exist", metric=series_limit_metric)
@@ -1930,7 +1990,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         database: Database,
         datasource_name: str,
         schema: Optional[str] = None,
-    ) -> List["SqlaTable"]:
+    ) -> List[SqlaTable]:
         query = (
             session.query(cls)
             .filter_by(database_id=database.id)
@@ -1947,7 +2007,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         database: Database,
         permissions: Set[str],
         schema_perms: Set[str],
-    ) -> List["SqlaTable"]:
+    ) -> List[SqlaTable]:
         # TODO(hughhhh): add unit test
         return (
             session.query(cls)
@@ -1964,7 +2024,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
     @classmethod
     def get_eager_sqlatable_datasource(
         cls, session: Session, datasource_id: int
-    ) -> "SqlaTable":
+    ) -> SqlaTable:
         """Returns SqlaTable with columns and metrics."""
         return (
             session.query(cls)
@@ -1977,7 +2037,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         )
 
     @classmethod
-    def get_all_datasources(cls, session: Session) -> List["SqlaTable"]:
+    def get_all_datasources(cls, session: Session) -> List[SqlaTable]:
         qry = session.query(cls)
         qry = cls.default_query(qry)
         return qry.all()
@@ -2038,7 +2098,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
     def before_update(
         mapper: Mapper,  # pylint: disable=unused-argument
         connection: Connection,  # pylint: disable=unused-argument
-        target: "SqlaTable",
+        target: SqlaTable,
     ) -> None:
         """
         Check before update if the target table already exists.
@@ -2110,7 +2170,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
     def after_insert(
         mapper: Mapper,
         connection: Connection,
-        sqla_table: "SqlaTable",
+        sqla_table: SqlaTable,
     ) -> None:
         """
         Update dataset permissions after insert
@@ -2124,7 +2184,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
     def after_delete(
         mapper: Mapper,
         connection: Connection,
-        sqla_table: "SqlaTable",
+        sqla_table: SqlaTable,
     ) -> None:
         """
         Update dataset permissions after delete
@@ -2135,7 +2195,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
     def after_update(
         mapper: Mapper,
         connection: Connection,
-        sqla_table: "SqlaTable",
+        sqla_table: SqlaTable,
     ) -> None:
         """
         Update dataset permissions
@@ -2170,7 +2230,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
             return
 
     def write_shadow_dataset(
-        self: "SqlaTable",
+        self: SqlaTable,
     ) -> None:
         """
         This method is deprecated
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 3088bdfb02..dfba16179d 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -201,22 +201,34 @@ class TestDatabaseModel(SupersetTestCase):
             "granularity": None,
             "from_dttm": None,
             "to_dttm": None,
-            "groupby": ["user", "expr"],
+            "columns": [
+                "user",
+                "expr",
+                {
+                    "hasCustomLabel": True,
+                    "label": "adhoc_column",
+                    "sqlExpression": "'{{ 'foo_' + time_grain }}'",
+                },
+            ],
             "metrics": [
                 {
+                    "hasCustomLabel": True,
+                    "label": "adhoc_metric",
                     "expressionType": AdhocMetricExpressionType.SQL,
-                    "sqlExpression": "SUM(case when user = '{{ current_username() }}' "
-                    "then 1 else 0 end)",
-                    "label": "SUM(userid)",
-                }
+                    "sqlExpression": "SUM(case when user = '{{ 'user_' + "
+                    "current_username() }}' then 1 else 0 end)",
+                },
+                "count_timegrain",
             ],
             "is_timeseries": False,
             "filter": [],
+            "extras": {"time_grain_sqla": "P1D"},
         }
 
         table = SqlaTable(
             table_name="test_has_jinja_metric_and_expr",
-            sql="SELECT '{{ current_username() }}' as user",
+            sql="SELECT '{{ 'user_' + current_username() }}' as user, "
+            "'{{ 'xyz_' + time_grain }}' as time_grain",
             database=get_example_database(),
         )
         TableColumn(
@@ -226,14 +238,25 @@ class TestDatabaseModel(SupersetTestCase):
             type="VARCHAR(100)",
             table=table,
         )
+        SqlMetric(
+            metric_name="count_timegrain",
+            expression="count('{{ 'bar_' + time_grain }}')",
+            table=table,
+        )
         db.session.commit()
 
         sqla_query = table.get_sqla_query(**base_query_obj)
         query = table.database.compile_sqla_query(sqla_query.sqla_query)
-        # assert expression
-        assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
-        # assert metric
-        assert "SUM(case when user = 'abc' then 1 else 0 end)" in query
+        # assert virtual dataset
+        assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query
+        # assert dataset calculated column
+        assert "case when 'abc' = 'abc' then 'yes' else 'no' end AS expr" in query
+        # assert adhoc column
+        assert "'foo_P1D'" in query
+        # assert dataset saved metric
+        assert "count('bar_P1D')" in query
+        # assert adhoc metric
+        assert "SUM(case when user = 'user_abc' then 1 else 0 end)" in query
         # Cleanup
         db.session.delete(table)
         db.session.commit()