You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2021/01/29 14:25:40 UTC
[superset] 02/18: fix(async queries): Remove "force" param on
cached data retrieval (#12103)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch 1.0
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 73523e282278ac85d9e72e0787b1737c44d498cf
Author: Rob DiCiuccio <ro...@gmail.com>
AuthorDate: Wed Jan 27 10:16:57 2021 -0800
fix(async queries): Remove "force" param on cached data retrieval (#12103)
* Async queries: remove force cache param on data retrieval
* Assert equal query_object cache keys
* Decouple etag_cache from permission checks
* Fix query_context test
* Use marshmallow EnumField for validation
---
setup.cfg | 2 +-
superset/charts/schemas.py | 14 ++++++--------
superset/common/query_context.py | 5 ++---
superset/utils/cache.py | 7 +------
superset/views/core.py | 9 ++++++---
superset/views/utils.py | 18 +++++++++++++++++
superset/viz.py | 16 ++++-----------
tests/query_context_tests.py | 42 ++++++++++++++++++++++++++++++++++------
8 files changed, 74 insertions(+), 39 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index c4d4140..9dd35f5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -30,7 +30,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,p [...]
+known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison [...]
multi_line_output = 3
order_by_type = false
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index f85ad00..d62707e 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -19,11 +19,14 @@ from typing import Any, Dict
from flask_babel import gettext as _
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from marshmallow.validate import Length, Range
+from marshmallow_enum import EnumField
from superset.common.query_context import QueryContext
from superset.utils import schema as utils
from superset.utils.core import (
AnnotationType,
+ ChartDataResultFormat,
+ ChartDataResultType,
FilterOperator,
PostProcessingBoxplotWhiskerType,
PostProcessingContributionOrientation,
@@ -1012,14 +1015,9 @@ class ChartDataQueryContextSchema(Schema):
description="Should the queries be forced to load from the source. "
"Default: `false`",
)
- result_type = fields.String(
- description="Type of results to return",
- validate=validate.OneOf(choices=("full", "query", "results", "samples")),
- )
- result_format = fields.String(
- description="Format of result payload",
- validate=validate.OneOf(choices=("json", "csv")),
- )
+
+ result_type = EnumField(ChartDataResultType, by_value=True)
+ result_format = EnumField(ChartDataResultFormat, by_value=True)
# pylint: disable=no-self-use,unused-argument
@post_load
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 0411007..3c20813 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -85,9 +85,8 @@ class QueryContext:
self.cache_values = {
"datasource": datasource,
"queries": queries,
- "force": force,
- "result_type": result_type,
- "result_format": result_format,
+ "result_type": self.result_type,
+ "result_format": self.result_format,
}
def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index 729c316..66da468 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -125,9 +125,7 @@ def memoized_func(
def etag_cache(
- check_perms: Callable[..., Any],
- cache: Cache = cache_manager.cache,
- max_age: Optional[Union[int, float]] = None,
+ cache: Cache = cache_manager.cache, max_age: Optional[Union[int, float]] = None,
) -> Callable[..., Any]:
"""
A decorator for caching views and handling etag conditional requests.
@@ -147,9 +145,6 @@ def etag_cache(
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
- # check if the user can access the resource
- check_perms(*args, **kwargs)
-
# for POST requests we can't set cache headers, use the response
# cache nor use conditional requests; this will still use the
# dataframe cache in `superset/viz.py`, though.
diff --git a/superset/views/core.py b/superset/views/core.py
index 2ef0e1a..5769cf8 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -125,6 +125,7 @@ from superset.views.utils import (
bootstrap_user_data,
check_datasource_perms,
check_explore_cache_perms,
+ check_resource_permissions,
check_slice_perms,
get_cta_schema_name,
get_dashboard_extra_filters,
@@ -456,7 +457,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@api
@has_access_api
@expose("/slice_json/<int:slice_id>")
- @etag_cache(check_perms=check_slice_perms)
+ @etag_cache()
+ @check_resource_permissions(check_slice_perms)
def slice_json(self, slice_id: int) -> FlaskResponse:
form_data, slc = get_form_data(slice_id, use_slice_data=True)
if not slc:
@@ -508,7 +510,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@handle_api_exception
@permission_name("explore_json")
@expose("/explore_json/data/<cache_key>", methods=["GET"])
- @etag_cache(check_perms=check_explore_cache_perms)
+ @check_resource_permissions(check_explore_cache_perms)
def explore_json_data(self, cache_key: str) -> FlaskResponse:
"""Serves cached result data for async explore_json calls
@@ -552,7 +554,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
methods=EXPLORE_JSON_METHODS,
)
@expose("/explore_json/", methods=EXPLORE_JSON_METHODS)
- @etag_cache(check_perms=check_datasource_perms)
+ @etag_cache()
+ @check_resource_permissions(check_datasource_perms)
def explore_json(
self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
) -> FlaskResponse:
diff --git a/superset/views/utils.py b/superset/views/utils.py
index 3ea253c..3162e14 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -17,6 +17,7 @@
import logging
from collections import defaultdict
from datetime import date
+from functools import wraps
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union
from urllib import parse
@@ -437,6 +438,23 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool:
return obj and user in obj.owners
+def check_resource_permissions(check_perms: Callable[..., Any],) -> Callable[..., Any]:
+ """
+ A decorator for checking permissions on a request using the passed-in function.
+ """
+
+ def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
+ @wraps(f)
+ def wrapper(*args: Any, **kwargs: Any) -> None:
+ # check if the user can access the resource
+ check_perms(*args, **kwargs)
+ return f(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
def check_explore_cache_perms(_self: Any, cache_key: str) -> None:
"""
Loads async explore_json request data from cache and performs access check
diff --git a/superset/viz.py b/superset/viz.py
index d8d799d..a2e2b98 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -143,13 +143,6 @@ class BaseViz:
self._force_cached = force_cached
self.from_dttm: Optional[datetime] = None
self.to_dttm: Optional[datetime] = None
-
- # Keeping track of whether some data came from cache
- # this is useful to trigger the <CachedLabel /> when
- # in the cases where visualization have many queries
- # (FilterBox for instance)
- self._any_cache_key: Optional[str] = None
- self._any_cached_dttm: Optional[str] = None
self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = []
self.process_metrics()
@@ -496,6 +489,7 @@ class BaseViz:
if not query_obj:
query_obj = self.query_obj()
cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None
+ cache_value = None
logger.info("Cache key: {}".format(cache_key))
is_loaded = False
stacktrace = None
@@ -507,8 +501,6 @@ class BaseViz:
try:
df = cache_value["df"]
self.query = cache_value["query"]
- self._any_cached_dttm = cache_value["dttm"]
- self._any_cache_key = cache_key
self.status = utils.QueryStatus.SUCCESS
is_loaded = True
stats_logger.incr("loaded_from_cache")
@@ -583,13 +575,13 @@ class BaseViz:
self.datasource.uid,
)
return {
- "cache_key": self._any_cache_key,
- "cached_dttm": self._any_cached_dttm,
+ "cache_key": cache_key,
+ "cached_dttm": cache_value["dttm"] if cache_value is not None else None,
"cache_timeout": self.cache_timeout,
"df": df,
"errors": self.errors,
"form_data": self.form_data,
- "is_cached": self._any_cache_key is not None,
+ "is_cached": cache_value is not None,
"query": self.query,
"from_dttm": self.from_dttm,
"to_dttm": self.to_dttm,
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index 2201900..99bc942 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -19,6 +19,8 @@ import pytest
from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.connectors.connector_registry import ConnectorRegistry
+from superset.extensions import cache_manager
+from superset.models.cache import CacheKey
from superset.utils.core import (
AdhocMetricExpressionType,
ChartDataResultFormat,
@@ -68,11 +70,39 @@ class TestQueryContext(SupersetTestCase):
self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
self.assertEqual(post_proc["options"], payload_post_proc["options"])
- def test_cache_key_changes_when_datasource_is_updated(self):
+ def test_cache(self):
+ table_name = "birth_names"
+ table = self.get_table_by_name(table_name)
+ payload = get_query_context(table.name, table.id)
+ payload["force"] = True
+
+ query_context = ChartDataQueryContextSchema().load(payload)
+ query_object = query_context.queries[0]
+ query_cache_key = query_context.query_cache_key(query_object)
+
+ response = query_context.get_payload(cache_query_context=True)
+ cache_key = response["cache_key"]
+ assert cache_key is not None
+
+ cached = cache_manager.cache.get(cache_key)
+ assert cached is not None
+
+ rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
+ rehydrated_qo = rehydrated_qc.queries[0]
+ rehydrated_query_cache_key = rehydrated_qc.query_cache_key(rehydrated_qo)
+
+ self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
+ self.assertEqual(len(rehydrated_qc.queries), 1)
+ self.assertEqual(query_cache_key, rehydrated_query_cache_key)
+ self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
+ self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
+ self.assertFalse(rehydrated_qc.force)
+
+ def test_query_cache_key_changes_when_datasource_is_updated(self):
self.login(username="admin")
payload = get_query_context("birth_names")
- # construct baseline cache_key
+ # construct baseline query_cache_key
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.query_cache_key(query_object)
@@ -89,7 +119,7 @@ class TestQueryContext(SupersetTestCase):
datasource.description = description_original
db.session.commit()
- # create new QueryContext with unchanged attributes and extract new cache_key
+ # create new QueryContext with unchanged attributes, extract new query_cache_key
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_new = query_context.query_cache_key(query_object)
@@ -97,16 +127,16 @@ class TestQueryContext(SupersetTestCase):
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
- def test_cache_key_changes_when_post_processing_is_updated(self):
+ def test_query_cache_key_changes_when_post_processing_is_updated(self):
self.login(username="admin")
payload = get_query_context("birth_names", add_postprocessing_operations=True)
- # construct baseline cache_key from query_context with post processing operation
+ # construct baseline query_cache_key from query_context with post processing operation
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.query_cache_key(query_object)
- # ensure added None post_processing operation doesn't change cache_key
+ # ensure added None post_processing operation doesn't change query_cache_key
payload["queries"][0]["post_processing"].append(None)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]