You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by ma...@apache.org on 2019/07/20 16:12:58 UTC

[incubator-superset] branch master updated: Add cache_key_wrapper to Jinja template processor (#7816)

This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 4568b2a  Add cache_key_wrapper to Jinja template processor (#7816)
4568b2a is described below

commit 4568b2a532069bdd5595c8f79a55d20e6326f9e4
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Sat Jul 20 19:12:35 2019 +0300

    Add cache_key_wrapper to Jinja template processor (#7816)
---
 docs/sqllab.rst                    |  2 ++
 superset/common/query_context.py   |  7 ++++-
 superset/common/query_object.py    |  2 +-
 superset/connectors/base/models.py | 13 +++++++---
 superset/connectors/sqla/models.py | 16 +++++++++---
 superset/jinja_context.py          | 52 +++++++++++++++++++++++++++++++++++---
 superset/viz.py                    |  1 +
 tests/sqla_models_tests.py         | 21 ++++++++++++++-
 8 files changed, 102 insertions(+), 12 deletions(-)

diff --git a/docs/sqllab.rst b/docs/sqllab.rst
index 5fe24ad..c60123f 100644
--- a/docs/sqllab.rst
+++ b/docs/sqllab.rst
@@ -87,6 +87,8 @@ Superset's Jinja context:
 
 .. autofunction:: superset.jinja_context.filter_values
 
+.. autofunction:: superset.jinja_context.CacheKeyWrapper.cache_key_wrapper
+
 .. autoclass:: superset.jinja_context.PrestoTemplateProcessor
     :members:
 
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 4992405..e8f7f62 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -152,8 +152,13 @@ class QueryContext:
 
     def get_df_payload(self, query_obj, **kwargs):
         """Handles caching around the df payload retrieval"""
+        extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj)
         cache_key = (
-            query_obj.cache_key(datasource=self.datasource.uid, **kwargs)
+            query_obj.cache_key(
+                datasource=self.datasource.uid,
+                extra_cache_keys=extra_cache_keys,
+                **kwargs
+            )
             if query_obj
             else None
         )
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index f413031..7d72aa5 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -107,7 +107,7 @@ class QueryObject:
 
     def cache_key(self, **extra):
         """
-        The cache key is made out of the key/values in `query_obj`, plus any
+        The cache key is made out of the key/values from to_dict(), plus any
         other key/values in `extra`
         We remove datetime bounds that are hard values, and replace them with
 the user-provided inputs to bounds, which may be time-relative (as in
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index e90d5d3..da7fec5 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=C,R,W
 import json
+from typing import Any, List
 
 from sqlalchemy import and_, Boolean, Column, Integer, String, Text
 from sqlalchemy.ext.declarative import declared_attr
@@ -73,9 +74,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
         )
 
     # placeholder for a relationship to a derivative of BaseColumn
-    columns = []
+    columns: List[Any] = []
     # placeholder for a relationship to a derivative of BaseMetric
-    metrics = []
+    metrics: List[Any] = []
 
     @property
     def uid(self):
@@ -329,6 +330,12 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
             obj.get("columns"), self.columns, self.column_class, "column_name"
         )
 
+    def get_extra_cache_keys(self, query_obj) -> List[Any]:
+        """ If a datasource needs to provide additional keys for calculation of
+        cache keys, those can be provided via this method
+        """
+        return []
+
 
 class BaseColumn(AuditMixinNullable, ImportMixin):
     """Interface for column"""
@@ -346,7 +353,7 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
     is_dttm = None
 
     # [optional] Set this to support import/export functionality
-    export_fields = []
+    export_fields: List[Any] = []
 
     def __repr__(self):
         return self.column_name
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c9eed47..dfedca3 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -18,7 +18,7 @@
 from collections import namedtuple, OrderedDict
 from datetime import datetime
 import logging
-from typing import Optional, Union
+from typing import Any, List, Optional, Union
 
 from flask import escape, Markup
 from flask_appbuilder import Model
@@ -61,7 +61,9 @@ from superset.utils import core as utils, import_datasource
 config = app.config
 metadata = Model.metadata  # pylint: disable=no-member
 
-SqlaQuery = namedtuple("SqlaQuery", ["sqla_query", "labels_expected"])
+SqlaQuery = namedtuple(
+    "SqlaQuery", ["sqla_query", "labels_expected", "extra_cache_keys"]
+)
 QueryStringExtended = namedtuple("QueryStringExtended", ["sql", "labels_expected"])
 
 
@@ -618,6 +620,8 @@ class SqlaTable(Model, BaseDatasource):
             "columns": {col.column_name: col for col in self.columns},
         }
         template_kwargs.update(self.template_params_dict)
+        extra_cache_keys: List[Any] = []
+        template_kwargs["extra_cache_keys"] = extra_cache_keys
         template_processor = self.get_template_processor(**template_kwargs)
         db_engine_spec = self.database.db_engine_spec
 
@@ -869,7 +873,9 @@ class SqlaTable(Model, BaseDatasource):
                 qry = qry.where(top_groups)
 
         return SqlaQuery(
-            sqla_query=qry.select_from(tbl), labels_expected=labels_expected
+            sqla_query=qry.select_from(tbl),
+            labels_expected=labels_expected,
+            extra_cache_keys=extra_cache_keys,
         )
 
     def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols):
@@ -1058,6 +1064,10 @@ class SqlaTable(Model, BaseDatasource):
     def default_query(qry):
         return qry.filter_by(is_sqllab_view=False)
 
+    def get_extra_cache_keys(self, query_obj) -> List[Any]:
+        sqla_query = self.get_sqla_query(**query_obj)
+        return sqla_query.extra_cache_keys
+
 
 sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
 sa.event.listen(SqlaTable, "after_update", security_manager.set_perm)
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index cfb7593..97b4dfe 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -129,7 +129,43 @@ def filter_values(column: str, default: Optional[str] = None) -> List[str]:
         return []
 
 
-class BaseTemplateProcessor(object):
+class CacheKeyWrapper:
+    """ Dummy class that exposes a method used to store additional values used in
+     calculation of query object cache keys"""
+
+    def __init__(self, extra_cache_keys: Optional[List[Any]] = None):
+        self.extra_cache_keys = extra_cache_keys
+
+    def cache_key_wrapper(self, key: Any) -> Any:
+        """ Adds values to a list that is added to the query object used for calculating
+        a cache key.
+
+        This is needed if the following applies:
+            - Caching is enabled
+            - The query is dynamically generated using a jinja template
+            - A username or similar is used as a filter in the query
+
+        Example when using a SQL query as a data source ::
+
+            SELECT action, count(*) as times
+            FROM logs
+            WHERE logged_in_user = '{{ cache_key_wrapper(current_username()) }}'
+            GROUP BY action
+
+        This will ensure that the query results that were cached by `user_1` will
+        **not** be seen by `user_2`, as the `cache_key` for the query will be
+        different. ``cache_key_wrapper`` can be used similarly for regular table data
+        sources by adding a `Custom SQL` filter.
+
+        :param key: Any value that should be considered when calculating the cache key
+        :return: the original value ``key`` passed to the function
+        """
+        if self.extra_cache_keys is not None:
+            self.extra_cache_keys.append(key)
+        return key
+
+
+class BaseTemplateProcessor:
     """Base class for database-specific jinja context
 
     There's this bit of magic in ``process_template`` that instantiates only
@@ -146,7 +182,14 @@ class BaseTemplateProcessor(object):
 
     engine: Optional[str] = None
 
-    def __init__(self, database=None, query=None, table=None, **kwargs):
+    def __init__(
+        self,
+        database=None,
+        query=None,
+        table=None,
+        extra_cache_keys: Optional[List[Any]] = None,
+        **kwargs
+    ):
         self.database = database
         self.query = query
         self.schema = None
@@ -158,6 +201,7 @@ class BaseTemplateProcessor(object):
             "url_param": url_param,
             "current_user_id": current_user_id,
             "current_username": current_username,
+            "cache_key_wrapper": CacheKeyWrapper(extra_cache_keys).cache_key_wrapper,
             "filter_values": filter_values,
             "form_data": {},
         }
@@ -189,7 +233,9 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
     engine = "presto"
 
     @staticmethod
-    def _schema_table(table_name: str, schema: str) -> Tuple[str, str]:
+    def _schema_table(
+        table_name: str, schema: Optional[str]
+    ) -> Tuple[str, Optional[str]]:
         if "." in table_name:
             schema, table_name = table_name.split(".")
         return table_name, schema
diff --git a/superset/viz.py b/superset/viz.py
index d81e16e..52075be 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -364,6 +364,7 @@ class BaseViz(object):
 
         cache_dict["time_range"] = self.form_data.get("time_range")
         cache_dict["datasource"] = self.datasource.uid
+        cache_dict["extra_cache_keys"] = self.datasource.get_extra_cache_keys(query_obj)
         json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode("utf-8")).hexdigest()
 
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index 93a2859..f089628 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -14,8 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from superset.connectors.sqla.models import TableColumn
+from superset import db
+from superset.connectors.sqla.models import SqlaTable, TableColumn
 from superset.db_engine_specs.druid import DruidEngineSpec
+from superset.utils.core import get_main_database
 from .base_tests import SupersetTestCase
 
 
@@ -39,3 +41,20 @@ class DatabaseModelTestCase(SupersetTestCase):
 
         col = TableColumn(column_name="foo", type="STRING")
         self.assertEquals(col.is_time, False)
+
+    def test_cache_key_wrapper(self):
+        query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
+        table = SqlaTable(sql=query, database=get_main_database(db.session))
+        query_obj = {
+            "granularity": None,
+            "from_dttm": None,
+            "to_dttm": None,
+            "groupby": ["user"],
+            "metrics": [],
+            "is_timeseries": False,
+            "filter": [],
+            "is_prequery": False,
+            "extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
+        }
+        extra_cache_keys = table.get_extra_cache_keys(query_obj)
+        self.assertListEqual(extra_cache_keys, ["user_1", "user_2"])