You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by ma...@apache.org on 2019/07/20 16:12:58 UTC
[incubator-superset] branch master updated: Add cache_key_wrapper to Jinja template processor (#7816)
This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 4568b2a Add cache_key_wrapper to Jinja template processor (#7816)
4568b2a is described below
commit 4568b2a532069bdd5595c8f79a55d20e6326f9e4
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Sat Jul 20 19:12:35 2019 +0300
Add cache_key_wrapper to Jinja template processor (#7816)
---
docs/sqllab.rst | 2 ++
superset/common/query_context.py | 7 ++++-
superset/common/query_object.py | 2 +-
superset/connectors/base/models.py | 13 +++++++---
superset/connectors/sqla/models.py | 16 +++++++++---
superset/jinja_context.py | 52 +++++++++++++++++++++++++++++++++++---
superset/viz.py | 1 +
tests/sqla_models_tests.py | 21 ++++++++++++++-
8 files changed, 102 insertions(+), 12 deletions(-)
diff --git a/docs/sqllab.rst b/docs/sqllab.rst
index 5fe24ad..c60123f 100644
--- a/docs/sqllab.rst
+++ b/docs/sqllab.rst
@@ -87,6 +87,8 @@ Superset's Jinja context:
.. autofunction:: superset.jinja_context.filter_values
+.. autofunction:: superset.jinja_context.CacheKeyWrapper.cache_key_wrapper
+
.. autoclass:: superset.jinja_context.PrestoTemplateProcessor
:members:
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 4992405..e8f7f62 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -152,8 +152,13 @@ class QueryContext:
def get_df_payload(self, query_obj, **kwargs):
"""Handles caching around the df paylod retrieval"""
+ extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj)
cache_key = (
- query_obj.cache_key(datasource=self.datasource.uid, **kwargs)
+ query_obj.cache_key(
+ datasource=self.datasource.uid,
+ extra_cache_keys=extra_cache_keys,
+ **kwargs
+ )
if query_obj
else None
)
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index f413031..7d72aa5 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -107,7 +107,7 @@ class QueryObject:
def cache_key(self, **extra):
"""
- The cache key is made out of the key/values in `query_obj`, plus any
+ The cache key is made out of the key/values from to_dict(), plus any
other key/values in `extra`
We remove datetime bounds that are hard values, and replace them with
the use-provided inputs to bounds, which may be time-relative (as in
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index e90d5d3..da7fec5 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=C,R,W
import json
+from typing import Any, List
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
@@ -73,9 +74,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
)
# placeholder for a relationship to a derivative of BaseColumn
- columns = []
+ columns: List[Any] = []
# placeholder for a relationship to a derivative of BaseMetric
- metrics = []
+ metrics: List[Any] = []
@property
def uid(self):
@@ -329,6 +330,12 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
obj.get("columns"), self.columns, self.column_class, "column_name"
)
+ def get_extra_cache_keys(self, query_obj) -> List[Any]:
+ """ If a datasource needs to provide additional keys for calculation of
+ cache keys, those can be provided via this method
+ """
+ return []
+
class BaseColumn(AuditMixinNullable, ImportMixin):
"""Interface for column"""
@@ -346,7 +353,7 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
is_dttm = None
# [optional] Set this to support import/export functionality
- export_fields = []
+ export_fields: List[Any] = []
def __repr__(self):
return self.column_name
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c9eed47..dfedca3 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -18,7 +18,7 @@
from collections import namedtuple, OrderedDict
from datetime import datetime
import logging
-from typing import Optional, Union
+from typing import Any, List, Optional, Union
from flask import escape, Markup
from flask_appbuilder import Model
@@ -61,7 +61,9 @@ from superset.utils import core as utils, import_datasource
config = app.config
metadata = Model.metadata # pylint: disable=no-member
-SqlaQuery = namedtuple("SqlaQuery", ["sqla_query", "labels_expected"])
+SqlaQuery = namedtuple(
+ "SqlaQuery", ["sqla_query", "labels_expected", "extra_cache_keys"]
+)
QueryStringExtended = namedtuple("QueryStringExtended", ["sql", "labels_expected"])
@@ -618,6 +620,8 @@ class SqlaTable(Model, BaseDatasource):
"columns": {col.column_name: col for col in self.columns},
}
template_kwargs.update(self.template_params_dict)
+ extra_cache_keys: List[Any] = []
+ template_kwargs["extra_cache_keys"] = extra_cache_keys
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.database.db_engine_spec
@@ -869,7 +873,9 @@ class SqlaTable(Model, BaseDatasource):
qry = qry.where(top_groups)
return SqlaQuery(
- sqla_query=qry.select_from(tbl), labels_expected=labels_expected
+ sqla_query=qry.select_from(tbl),
+ labels_expected=labels_expected,
+ extra_cache_keys=extra_cache_keys,
)
def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols):
@@ -1058,6 +1064,10 @@ class SqlaTable(Model, BaseDatasource):
def default_query(qry):
return qry.filter_by(is_sqllab_view=False)
+ def get_extra_cache_keys(self, query_obj) -> List[Any]:
+ sqla_query = self.get_sqla_query(**query_obj)
+ return sqla_query.extra_cache_keys
+
sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
sa.event.listen(SqlaTable, "after_update", security_manager.set_perm)
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index cfb7593..97b4dfe 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -129,7 +129,43 @@ def filter_values(column: str, default: Optional[str] = None) -> List[str]:
return []
-class BaseTemplateProcessor(object):
+class CacheKeyWrapper:
+ """ Dummy class that exposes a method used to store additional values used in
+ calculation of query object cache keys"""
+
+ def __init__(self, extra_cache_keys: Optional[List[Any]] = None):
+ self.extra_cache_keys = extra_cache_keys
+
+ def cache_key_wrapper(self, key: Any) -> Any:
+ """ Adds values to a list that is added to the query object used for calculating
+ a cache key.
+
+ This is needed if the following applies:
+ - Caching is enabled
+ - The query is dynamically generated using a jinja template
+ - A username or similar is used as a filter in the query
+
+ Example when using a SQL query as a data source ::
+
+ SELECT action, count(*) as times
+ FROM logs
+ WHERE logged_in_user = '{{ cache_key_wrapper(current_username()) }}'
+ GROUP BY action
+
+ This will ensure that the query results that were cached by `user_1` will
+ **not** be seen by `user_2`, as the `cache_key` for the query will be
+ different. ``cache_key_wrapper`` can be used similarly for regular table data
+ sources by adding a `Custom SQL` filter.
+
+ :param key: Any value that should be considered when calculating the cache key
+ :return: the original value ``key`` passed to the function
+ """
+ if self.extra_cache_keys is not None:
+ self.extra_cache_keys.append(key)
+ return key
+
+
+class BaseTemplateProcessor:
"""Base class for database-specific jinja context
There's this bit of magic in ``process_template`` that instantiates only
@@ -146,7 +182,14 @@ class BaseTemplateProcessor(object):
engine: Optional[str] = None
- def __init__(self, database=None, query=None, table=None, **kwargs):
+ def __init__(
+ self,
+ database=None,
+ query=None,
+ table=None,
+ extra_cache_keys: Optional[List[Any]] = None,
+ **kwargs
+ ):
self.database = database
self.query = query
self.schema = None
@@ -158,6 +201,7 @@ class BaseTemplateProcessor(object):
"url_param": url_param,
"current_user_id": current_user_id,
"current_username": current_username,
+ "cache_key_wrapper": CacheKeyWrapper(extra_cache_keys).cache_key_wrapper,
"filter_values": filter_values,
"form_data": {},
}
@@ -189,7 +233,9 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
engine = "presto"
@staticmethod
- def _schema_table(table_name: str, schema: str) -> Tuple[str, str]:
+ def _schema_table(
+ table_name: str, schema: Optional[str]
+ ) -> Tuple[str, Optional[str]]:
if "." in table_name:
schema, table_name = table_name.split(".")
return table_name, schema
diff --git a/superset/viz.py b/superset/viz.py
index d81e16e..52075be 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -364,6 +364,7 @@ class BaseViz(object):
cache_dict["time_range"] = self.form_data.get("time_range")
cache_dict["datasource"] = self.datasource.uid
+ cache_dict["extra_cache_keys"] = self.datasource.get_extra_cache_keys(query_obj)
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index 93a2859..f089628 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from superset.connectors.sqla.models import TableColumn
+from superset import db
+from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
+from superset.utils.core import get_main_database
from .base_tests import SupersetTestCase
@@ -39,3 +41,20 @@ class DatabaseModelTestCase(SupersetTestCase):
col = TableColumn(column_name="foo", type="STRING")
self.assertEquals(col.is_time, False)
+
+ def test_cache_key_wrapper(self):
+ query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
+ table = SqlaTable(sql=query, database=get_main_database(db.session))
+ query_obj = {
+ "granularity": None,
+ "from_dttm": None,
+ "to_dttm": None,
+ "groupby": ["user"],
+ "metrics": [],
+ "is_timeseries": False,
+ "filter": [],
+ "is_prequery": False,
+ "extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
+ }
+ extra_cache_keys = table.get_extra_cache_keys(query_obj)
+ self.assertListEqual(extra_cache_keys, ["user_1", "user_2"])