Posted to commits@superset.apache.org by vi...@apache.org on 2021/01/29 14:25:40 UTC

[superset] 02/18: fix(async queries): Remove "force" param on cached data retrieval (#12103)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 1.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 73523e282278ac85d9e72e0787b1737c44d498cf
Author: Rob DiCiuccio <ro...@gmail.com>
AuthorDate: Wed Jan 27 10:16:57 2021 -0800

    fix(async queries): Remove "force" param on cached data retrieval (#12103)
    
    * Async queries: remove force cache param on data retrieval
    
    * Assert equal query_object cache keys
    
    * Decouple etag_cache from permission checks
    
    * Fix query_context test
    
    * Use marshmallow EnumField for validation
---
 setup.cfg                        |  2 +-
 superset/charts/schemas.py       | 14 ++++++--------
 superset/common/query_context.py |  5 ++---
 superset/utils/cache.py          |  7 +------
 superset/views/core.py           |  9 ++++++---
 superset/views/utils.py          | 18 +++++++++++++++++
 superset/viz.py                  | 16 ++++-----------
 tests/query_context_tests.py     | 42 ++++++++++++++++++++++++++++++++++------
 8 files changed, 74 insertions(+), 39 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index c4d4140..9dd35f5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -30,7 +30,7 @@ combine_as_imports = true
 include_trailing_comma = true
 line_length = 88
 known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,p [...]
+known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison [...]
 multi_line_output = 3
 order_by_type = false
 
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index f85ad00..d62707e 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -19,11 +19,14 @@ from typing import Any, Dict
 from flask_babel import gettext as _
 from marshmallow import EXCLUDE, fields, post_load, Schema, validate
 from marshmallow.validate import Length, Range
+from marshmallow_enum import EnumField
 
 from superset.common.query_context import QueryContext
 from superset.utils import schema as utils
 from superset.utils.core import (
     AnnotationType,
+    ChartDataResultFormat,
+    ChartDataResultType,
     FilterOperator,
     PostProcessingBoxplotWhiskerType,
     PostProcessingContributionOrientation,
@@ -1012,14 +1015,9 @@ class ChartDataQueryContextSchema(Schema):
         description="Should the queries be forced to load from the source. "
         "Default: `false`",
     )
-    result_type = fields.String(
-        description="Type of results to return",
-        validate=validate.OneOf(choices=("full", "query", "results", "samples")),
-    )
-    result_format = fields.String(
-        description="Format of result payload",
-        validate=validate.OneOf(choices=("json", "csv")),
-    )
+
+    result_type = EnumField(ChartDataResultType, by_value=True)
+    result_format = EnumField(ChartDataResultFormat, by_value=True)
 
     # pylint: disable=no-self-use,unused-argument
     @post_load
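
Here the EnumField swap keeps the accepted wire values ("full"/"query"/"results"/"samples" and "json"/"csv") but derives them from the enums instead of hand-maintained OneOf lists. A minimal sketch of how by_value=True behaves, using a stand-in enum rather than Superset's actual ChartDataResultFormat:

    from enum import Enum

    from marshmallow import Schema, ValidationError
    from marshmallow_enum import EnumField

    class ResultFormat(Enum):  # stand-in for ChartDataResultFormat
        CSV = "csv"
        JSON = "json"

    class DemoSchema(Schema):
        result_format = EnumField(ResultFormat, by_value=True)

    # by_value=True matches on the enum value, not the member name
    print(DemoSchema().load({"result_format": "csv"}))
    # {'result_format': <ResultFormat.CSV: 'csv'>}

    try:
        DemoSchema().load({"result_format": "xml"})
    except ValidationError as err:
        print("rejected:", err.messages)
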
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 0411007..3c20813 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -85,9 +85,8 @@ class QueryContext:
         self.cache_values = {
             "datasource": datasource,
             "queries": queries,
-            "force": force,
-            "result_type": result_type,
-            "result_format": result_format,
+            "result_type": self.result_type,
+            "result_format": self.result_format,
         }
 
     def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
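
Dropping "force" here is the heart of the fix: cache_values is the serialized query context stored for async retrieval, and persisting a caller's force=True meant every later rehydration re-queried the source instead of reading the data that had just been cached. A bare-bones sketch of the store/rehydrate round trip, with plain dicts and hypothetical helper names standing in for Superset's QueryContext machinery:

    from typing import Any, Dict, List

    _cache: Dict[str, Dict[str, Any]] = {}

    def store_query_context(key: str, datasource: str, queries: List[Any]) -> None:
        # "force" is deliberately not persisted with the context
        _cache[key] = {"datasource": datasource, "queries": queries}

    def rehydrate(key: str) -> Dict[str, Any]:
        data = _cache[key]
        # absent from the payload, force falls back to its default
        return {**data, "force": data.get("force", False)}

    store_query_context("k1", "birth_names", [{"metrics": ["count"]}])
    assert rehydrate("k1")["force"] is False  # even if the original request forced
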
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index 729c316..66da468 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -125,9 +125,7 @@ def memoized_func(
 
 
 def etag_cache(
-    check_perms: Callable[..., Any],
-    cache: Cache = cache_manager.cache,
-    max_age: Optional[Union[int, float]] = None,
+    cache: Cache = cache_manager.cache, max_age: Optional[Union[int, float]] = None,
 ) -> Callable[..., Any]:
     """
     A decorator for caching views and handling etag conditional requests.
@@ -147,9 +145,6 @@ def etag_cache(
     def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
         @wraps(f)
         def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
-            # check if the user can access the resource
-            check_perms(*args, **kwargs)
-
             # for POST requests we can't set cache headers, use the response
             # cache nor use conditional requests; this will still use the
             # dataframe cache in `superset/viz.py`, though.
diff --git a/superset/views/core.py b/superset/views/core.py
index 2ef0e1a..5769cf8 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -125,6 +125,7 @@ from superset.views.utils import (
     bootstrap_user_data,
     check_datasource_perms,
     check_explore_cache_perms,
+    check_resource_permissions,
     check_slice_perms,
     get_cta_schema_name,
     get_dashboard_extra_filters,
@@ -456,7 +457,8 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @api
     @has_access_api
     @expose("/slice_json/<int:slice_id>")
-    @etag_cache(check_perms=check_slice_perms)
+    @etag_cache()
+    @check_resource_permissions(check_slice_perms)
     def slice_json(self, slice_id: int) -> FlaskResponse:
         form_data, slc = get_form_data(slice_id, use_slice_data=True)
         if not slc:
@@ -508,7 +510,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @handle_api_exception
     @permission_name("explore_json")
     @expose("/explore_json/data/<cache_key>", methods=["GET"])
-    @etag_cache(check_perms=check_explore_cache_perms)
+    @check_resource_permissions(check_explore_cache_perms)
     def explore_json_data(self, cache_key: str) -> FlaskResponse:
         """Serves cached result data for async explore_json calls
 
@@ -552,7 +554,8 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         methods=EXPLORE_JSON_METHODS,
     )
     @expose("/explore_json/", methods=EXPLORE_JSON_METHODS)
-    @etag_cache(check_perms=check_datasource_perms)
+    @etag_cache()
+    @check_resource_permissions(check_datasource_perms)
     def explore_json(
         self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
     ) -> FlaskResponse:
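
One subtlety in the hunks above: Python applies decorators bottom-up, so check_resource_permissions wraps the raw view and etag_cache() wraps that result, making the cache wrapper the first to execute on each request. A generic, runnable reminder of the ordering (plain Python, not Superset code):

    from functools import wraps

    def tag(label: str):
        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                print(f"enter {label}")
                return f(*args, **kwargs)
            return wrapper
        return decorator

    @tag("etag_cache")                  # outermost: runs first
    @tag("check_resource_permissions")  # innermost: runs second
    def view() -> str:
        return "body"

    view()  # prints "enter etag_cache", then "enter check_resource_permissions"
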
diff --git a/superset/views/utils.py b/superset/views/utils.py
index 3ea253c..3162e14 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -17,6 +17,7 @@
 import logging
 from collections import defaultdict
 from datetime import date
+from functools import wraps
 from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union
 from urllib import parse
 
@@ -437,6 +438,23 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool:
     return obj and user in obj.owners
 
 
+def check_resource_permissions(check_perms: Callable[..., Any],) -> Callable[..., Any]:
+    """
+    A decorator for checking permissions on a request using the passed-in function.
+    """
+
+    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
+        @wraps(f)
+        def wrapper(*args: Any, **kwargs: Any) -> None:
+            # check if the user can access the resource
+            check_perms(*args, **kwargs)
+            return f(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 def check_explore_cache_perms(_self: Any, cache_key: str) -> None:
     """
     Loads async explore_json request data from cache and performs access check
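
A quick usage sketch of the new decorator in isolation. The check function and view below are made up, and the decorator is restated from the hunk above (with the wrapper's return annotation widened to Any, since it returns the view's result) so the snippet runs standalone:

    from functools import wraps
    from typing import Any, Callable

    def check_resource_permissions(
        check_perms: Callable[..., Any]
    ) -> Callable[..., Any]:
        def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
            @wraps(f)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                check_perms(*args, **kwargs)  # raises if access is denied
                return f(*args, **kwargs)
            return wrapper
        return decorator

    def check_slice_perms_stub(slice_id: int) -> None:  # hypothetical
        if slice_id != 42:
            raise PermissionError("no access to this slice")

    @check_resource_permissions(check_slice_perms_stub)
    def slice_data(slice_id: int) -> str:
        return f"payload for slice {slice_id}"

    print(slice_data(42))  # payload for slice 42
    # slice_data(7) raises PermissionError before the view body runs
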
diff --git a/superset/viz.py b/superset/viz.py
index d8d799d..a2e2b98 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -143,13 +143,6 @@ class BaseViz:
         self._force_cached = force_cached
         self.from_dttm: Optional[datetime] = None
         self.to_dttm: Optional[datetime] = None
-
-        # Keeping track of whether some data came from cache
-        # this is useful to trigger the <CachedLabel /> when
-        # in the cases where visualization have many queries
-        # (FilterBox for instance)
-        self._any_cache_key: Optional[str] = None
-        self._any_cached_dttm: Optional[str] = None
         self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = []
 
         self.process_metrics()
@@ -496,6 +489,7 @@ class BaseViz:
         if not query_obj:
             query_obj = self.query_obj()
         cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None
+        cache_value = None
         logger.info("Cache key: {}".format(cache_key))
         is_loaded = False
         stacktrace = None
@@ -507,8 +501,6 @@ class BaseViz:
                 try:
                     df = cache_value["df"]
                     self.query = cache_value["query"]
-                    self._any_cached_dttm = cache_value["dttm"]
-                    self._any_cache_key = cache_key
                     self.status = utils.QueryStatus.SUCCESS
                     is_loaded = True
                     stats_logger.incr("loaded_from_cache")
@@ -583,13 +575,13 @@ class BaseViz:
                     self.datasource.uid,
                 )
         return {
-            "cache_key": self._any_cache_key,
-            "cached_dttm": self._any_cached_dttm,
+            "cache_key": cache_key,
+            "cached_dttm": cache_value["dttm"] if cache_value is not None else None,
             "cache_timeout": self.cache_timeout,
             "df": df,
             "errors": self.errors,
             "form_data": self.form_data,
-            "is_cached": self._any_cache_key is not None,
+            "is_cached": cache_value is not None,
             "query": self.query,
             "from_dttm": self.from_dttm,
             "to_dttm": self.to_dttm,
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index 2201900..99bc942 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -19,6 +19,8 @@ import pytest
 from superset import db
 from superset.charts.schemas import ChartDataQueryContextSchema
 from superset.connectors.connector_registry import ConnectorRegistry
+from superset.extensions import cache_manager
+from superset.models.cache import CacheKey
 from superset.utils.core import (
     AdhocMetricExpressionType,
     ChartDataResultFormat,
@@ -68,11 +70,39 @@ class TestQueryContext(SupersetTestCase):
                 self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
                 self.assertEqual(post_proc["options"], payload_post_proc["options"])
 
-    def test_cache_key_changes_when_datasource_is_updated(self):
+    def test_cache(self):
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id)
+        payload["force"] = True
+
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_object = query_context.queries[0]
+        query_cache_key = query_context.query_cache_key(query_object)
+
+        response = query_context.get_payload(cache_query_context=True)
+        cache_key = response["cache_key"]
+        assert cache_key is not None
+
+        cached = cache_manager.cache.get(cache_key)
+        assert cached is not None
+
+        rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
+        rehydrated_qo = rehydrated_qc.queries[0]
+        rehydrated_query_cache_key = rehydrated_qc.query_cache_key(rehydrated_qo)
+
+        self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
+        self.assertEqual(len(rehydrated_qc.queries), 1)
+        self.assertEqual(query_cache_key, rehydrated_query_cache_key)
+        self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
+        self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
+        self.assertFalse(rehydrated_qc.force)
+
+    def test_query_cache_key_changes_when_datasource_is_updated(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
 
-        # construct baseline cache_key
+        # construct baseline query_cache_key
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_original = query_context.query_cache_key(query_object)
@@ -89,7 +119,7 @@ class TestQueryContext(SupersetTestCase):
         datasource.description = description_original
         db.session.commit()
 
-        # create new QueryContext with unchanged attributes and extract new cache_key
+        # create new QueryContext with unchanged attributes, extract new query_cache_key
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_new = query_context.query_cache_key(query_object)
@@ -97,16 +127,16 @@ class TestQueryContext(SupersetTestCase):
         # the new cache_key should be different due to updated datasource
         self.assertNotEqual(cache_key_original, cache_key_new)
 
-    def test_cache_key_changes_when_post_processing_is_updated(self):
+    def test_query_cache_key_changes_when_post_processing_is_updated(self):
         self.login(username="admin")
         payload = get_query_context("birth_names", add_postprocessing_operations=True)
 
-        # construct baseline cache_key from query_context with post processing operation
+        # construct baseline query_cache_key from query_context with post processing operation
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_original = query_context.query_cache_key(query_object)
 
-        # ensure added None post_processing operation doesn't change cache_key
+        # ensure added None post_processing operation doesn't change query_cache_key
         payload["queries"][0]["post_processing"].append(None)
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]