You are viewing a plain-text version of this content. Its canonical link was a hyperlink that is not preserved in this plain-text copy.
Posted to dev@superset.apache.org by GitBox <gi...@apache.org> on 2018/02/07 22:49:23 UTC

[GitHub] mistercrunch closed pull request #4316: Fix caching related issues

mistercrunch closed pull request #4316: Fix caching related issues
URL: https://github.com/apache/incubator-superset/pull/4316
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it after the merge:

diff --git a/superset/assets/package.json b/superset/assets/package.json
index c944ad2fa0..abc978c079 100644
--- a/superset/assets/package.json
+++ b/superset/assets/package.json
@@ -93,8 +93,8 @@
     "react-sortable-hoc": "^0.6.7",
     "react-split-pane": "^0.1.66",
     "react-syntax-highlighter": "^5.7.0",
-    "react-virtualized": "^9.3.0",
-    "react-virtualized-select": "^2.4.0",
+    "react-virtualized": "9.3.0",
+    "react-virtualized-select": "2.4.0",
     "reactable": "^0.14.1",
     "redux": "^3.5.2",
     "redux-localstorage": "^0.4.1",
diff --git a/superset/models/core.py b/superset/models/core.py
index 142482bdbd..cfd6d75203 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -230,7 +230,7 @@ def slice_link(self):
         name = escape(self.slice_name)
         return Markup('<a href="{url}">{name}</a>'.format(**locals()))
 
-    def get_viz(self):
+    def get_viz(self, force=False):
         """Creates :py:class:viz.BaseViz object from the url_params_multidict.
 
         :return: object of the 'viz_type' type that is taken from the
@@ -246,6 +246,7 @@ def get_viz(self):
         return viz_types[slice_params.get('viz_type')](
             self.datasource,
             form_data=slice_params,
+            force=force,
         )
 
     @classmethod
diff --git a/superset/views/core.py b/superset/views/core.py
index 03636db826..60990b4018 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -942,7 +942,9 @@ def get_viz(
             slice_id=None,
             form_data=None,
             datasource_type=None,
-            datasource_id=None):
+            datasource_id=None,
+            force=False,
+    ):
         if slice_id:
             slc = (
                 db.session.query(models.Slice)
@@ -957,6 +959,7 @@ def get_viz(
             viz_obj = viz.viz_types[viz_type](
                 datasource,
                 form_data=form_data,
+                force=force,
             )
             return viz_obj
 
@@ -1005,7 +1008,9 @@ def generate_json(self, datasource_type, datasource_id, form_data,
             viz_obj = self.get_viz(
                 datasource_type=datasource_type,
                 datasource_id=datasource_id,
-                form_data=form_data)
+                form_data=form_data,
+                force=force,
+            )
         except Exception as e:
             logging.exception(e)
             return json_error_response(
@@ -1026,7 +1031,7 @@ def generate_json(self, datasource_type, datasource_id, form_data,
             return self.get_query_string_response(viz_obj)
 
         try:
-            payload = viz_obj.get_payload(force=force)
+            payload = viz_obj.get_payload()
         except Exception as e:
             logging.exception(e)
             return json_error_response(utils.error_msg_from_exception(e))
@@ -1070,9 +1075,10 @@ def annotation_json(self, layer_id):
         viz_obj = viz.viz_types['table'](
           datasource,
           form_data=form_data,
+          force=False,
         )
         try:
-            payload = viz_obj.get_payload(force=False)
+            payload = viz_obj.get_payload()
         except Exception as e:
             logging.exception(e)
             return json_error_response(utils.error_msg_from_exception(e))
@@ -1864,8 +1870,8 @@ def warm_up_cache(self):
 
         for slc in slices:
             try:
-                obj = slc.get_viz()
-                obj.get_json(force=True)
+                obj = slc.get_viz(force=True)
+                obj.get_json()
             except Exception as e:
                 return json_error_response(utils.error_msg_from_exception(e))
         return json_success(json.dumps(
diff --git a/superset/viz.py b/superset/viz.py
index 6fc8cf9887..2154bcc328 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -18,10 +18,9 @@
 import math
 import traceback
 import uuid
-import zlib
 
 from dateutil import relativedelta as rdelta
-from flask import request
+from flask import escape, request
 from flask_babel import lazy_gettext as _
 import geohash
 from markdown import markdown
@@ -30,8 +29,8 @@
 from pandas.tseries.frequencies import to_offset
 import polyline
 import simplejson as json
-from six import PY3, string_types, text_type
-from six.moves import reduce
+from six import string_types, text_type
+from six.moves import cPickle as pkl, reduce
 
 from superset import app, cache, get_manifest_file, utils
 from superset.utils import DTTM_ALIAS, merge_extra_filters
@@ -49,8 +48,9 @@ class BaseViz(object):
     credits = ''
     is_timeseries = False
     default_fillna = 0
+    cache_type = 'df'
 
-    def __init__(self, datasource, form_data):
+    def __init__(self, datasource, form_data, force=False):
         if not datasource:
             raise Exception(_('Viz is missing a datasource'))
         self.datasource = datasource
@@ -67,6 +67,40 @@ def __init__(self, datasource, form_data):
 
         self.status = None
         self.error_message = None
+        self.force = force
+
+        # Keeping track of whether some data came from cache
+        # this is useful to trigerr the <CachedLabel /> when
+        # in the cases where visualization have many queries
+        # (FilterBox for instance)
+        self._some_from_cache = False
+        self._any_cache_key = None
+        self._any_cached_dttm = None
+
+        self.run_extra_queries()
+
+    def run_extra_queries(self):
+        """Lyfecycle method to use when more than one query is needed
+
+        In rare-ish cases, a visualization may need to execute multiple
+        queries. That is the case for FilterBox or for time comparison
+        in Line chart for instance.
+
+        In those cases, we need to make sure these queries run before the
+        main `get_payload` method gets called, so that the overall caching
+        metadata can be right. The way it works here is that if any of
+        the previous `get_df_payload` calls hit the cache, the main
+        payload's metadata will reflect that.
+
+        The multi-query support may need more work to become a first class
+        use case in the framework, and for the UI to reflect the subtleties
+        (show that only some of the queries were served from cache for
+        instance). In the meantime, since multi-query is rare, we treat
+        it with a bit of a hack. Note that the hack became necessary
+        when moving from caching the visualization's data itself, to caching
+        the underlying query(ies).
+        """
+        pass
 
     def get_fillna_for_type(self, col_type):
         """Returns the value for use as filler for a specific Column.type"""
@@ -222,9 +256,9 @@ def cache_timeout(self):
             return self.datasource.database.cache_timeout
         return config.get('CACHE_DEFAULT_TIMEOUT')
 
-    def get_json(self, force=False):
+    def get_json(self):
         return json.dumps(
-            self.get_payload(force),
+            self.get_payload(),
             default=utils.json_int_dttm_ser, ignore_nan=True)
 
     def cache_key(self, query_obj):
@@ -246,64 +280,73 @@ def cache_key(self, query_obj):
         json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode('utf-8')).hexdigest()
 
-    def get_payload(self, force=False):
-        """Handles caching around the json payload retrieval"""
-        query_obj = self.query_obj()
+    def get_payload(self, query_obj=None):
+        """Returns a payload of metadata and data"""
+        payload = self.get_df_payload(query_obj)
+        df = payload['df']
+        if df is not None:
+            payload['data'] = self.get_data(df)
+        del payload['df']
+        return payload
+
+    def get_df_payload(self, query_obj=None):
+        """Handles caching around the df payload retrieval"""
+        if not query_obj:
+            query_obj = self.query_obj()
         cache_key = self.cache_key(query_obj) if query_obj else None
-        cached_dttm = None
-        data = None
+        logging.info('Cache key: {}'.format(cache_key))
+        is_loaded = False
         stacktrace = None
-        rowcount = None
-        if cache_key and cache and not force:
+        df = None
+        cached_dttm = datetime.utcnow().isoformat().split('.')[0]
+        if cache_key and cache and not self.force:
             cache_value = cache.get(cache_key)
             if cache_value:
                 stats_logger.incr('loaded_from_cache')
-                is_cached = True
                 try:
-                    cache_value = zlib.decompress(cache_value)
-                    if PY3:
-                        cache_value = cache_value.decode('utf-8')
-                    cache_value = json.loads(cache_value)
-                    data = cache_value['data']
-                    cached_dttm = cache_value['dttm']
+                    cache_value = pkl.loads(cache_value)
+                    df = cache_value['df']
+                    is_loaded = True
+                    self._any_cache_key = cache_key
+                    self._any_cached_dttm = cache_value['dttm']
                 except Exception as e:
+                    logging.exception(e)
                     logging.error('Error reading cache: ' +
                                   utils.error_msg_from_exception(e))
-                    data = None
                 logging.info('Serving from cache')
 
-        if not data:
-            stats_logger.incr('loaded_from_source')
-            is_cached = False
+        if query_obj and not is_loaded:
             try:
                 df = self.get_df(query_obj)
-                if not self.error_message:
-                    data = self.get_data(df)
-                rowcount = len(df.index) if df is not None else 0
+                stats_logger.incr('loaded_from_source')
+                is_loaded = True
             except Exception as e:
                 logging.exception(e)
                 if not self.error_message:
-                    self.error_message = str(e)
+                    self.error_message = escape('{}'.format(e))
                 self.status = utils.QueryStatus.FAILED
-                data = None
                 stacktrace = traceback.format_exc()
 
             if (
-                    data and
+                    is_loaded and
                     cache_key and
                     cache and
                     self.status != utils.QueryStatus.FAILED):
-                cached_dttm = datetime.utcnow().isoformat().split('.')[0]
                 try:
-                    cache_value = self.json_dumps({
-                        'data': data,
-                        'dttm': cached_dttm,
-                    })
-                    if PY3:
-                        cache_value = bytes(cache_value, 'utf-8')
+                    cache_value = dict(
+                        dttm=cached_dttm,
+                        df=df if df is not None else None,
+                    )
+                    cache_value = pkl.dumps(
+                        cache_value, protocol=pkl.HIGHEST_PROTOCOL)
+
+                    logging.info('Caching {} chars at key {}'.format(
+                        len(cache_value), cache_key))
+
+                    stats_logger.incr('set_cache_key')
                     cache.set(
                         cache_key,
-                        zlib.compress(cache_value),
+                        cache_value,
                         timeout=self.cache_timeout)
                 except Exception as e:
                     # cache.set call can fail if the backend is down or if
@@ -313,17 +356,17 @@ def get_payload(self, force=False):
                     cache.delete(cache_key)
 
         return {
-            'cache_key': cache_key,
-            'cached_dttm': cached_dttm,
+            'cache_key': self._any_cache_key,
+            'cached_dttm': self._any_cached_dttm,
             'cache_timeout': self.cache_timeout,
-            'data': data,
+            'df': df,
             'error': self.error_message,
             'form_data': self.form_data,
-            'is_cached': is_cached,
+            'is_cached': self._any_cache_key is not None,
             'query': self.query,
             'status': self.status,
             'stacktrace': stacktrace,
-            'rowcount': rowcount,
+            'rowcount': len(df.index) if df is not None else 0,
         }
 
     def json_dumps(self, obj, sort_keys=False):
@@ -412,7 +455,11 @@ def query_obj(self):
 
     def get_data(self, df):
         fd = self.form_data
-        if not self.should_be_timeseries() and DTTM_ALIAS in df:
+        if (
+                not self.should_be_timeseries() and
+                df is not None and
+                DTTM_ALIAS in df
+        ):
             del df[DTTM_ALIAS]
 
         # Sum up and compute percentages for all percent metrics
@@ -1059,12 +1106,10 @@ def process_data(self, df, aggregate=False):
             df = df[num_period_compare:]
         return df
 
-    def get_data(self, df):
+    def run_extra_queries(self):
         fd = self.form_data
-        df = self.process_data(df)
-        chart_data = self.to_series(df)
-
         time_compare = fd.get('time_compare')
+        self.extra_chart_data = None
         if time_compare:
             query_object = self.query_obj()
             delta = utils.parse_human_timedelta(time_compare)
@@ -1073,12 +1118,20 @@ def get_data(self, df):
             query_object['from_dttm'] -= delta
             query_object['to_dttm'] -= delta
 
-            df2 = self.get_df(query_object)
+            df2 = self.get_df_payload(query_object).get('df')
             df2[DTTM_ALIAS] += delta
             df2 = self.process_data(df2)
-            chart_data += self.to_series(
+            self.extra_chart_data = self.to_series(
                 df2, classed='superset', title_suffix='---')
+
+    def get_data(self, df):
+        df = self.process_data(df)
+        chart_data = self.to_series(df)
+
+        if self.extra_chart_data:
+            chart_data += self.extra_chart_data
             chart_data = sorted(chart_data, key=lambda x: x['key'])
+
         return chart_data
 
 
@@ -1556,10 +1609,20 @@ class FilterBoxViz(BaseViz):
     verbose_name = _('Filters')
     is_timeseries = False
     credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
+    cache_type = 'get_data'
 
     def query_obj(self):
         return None
 
+    def run_extra_queries(self):
+        qry = self.filter_query_obj()
+        filters = [g for g in self.form_data['groupby']]
+        self.dataframes = {}
+        for flt in filters:
+            qry['groupby'] = [flt]
+            df = self.get_df_payload(query_obj=qry).get('df')
+            self.dataframes[flt] = df
+
     def filter_query_obj(self):
         qry = super(FilterBoxViz, self).query_obj()
         groupby = self.form_data.get('groupby')
@@ -1570,12 +1633,10 @@ def filter_query_obj(self):
         return qry
 
     def get_data(self, df):
-        qry = self.filter_query_obj()
-        filters = [g for g in self.form_data['groupby']]
         d = {}
+        filters = [g for g in self.form_data['groupby']]
         for flt in filters:
-            qry['groupby'] = [flt]
-            df = super(FilterBoxViz, self).get_df(qry)
+            df = self.dataframes[flt]
             d[flt] = [{
                 'id': row[0],
                 'text': row[0],


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services