You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this rendering.
Posted to commits@superset.apache.org by ma...@apache.org on 2018/05/15 17:37:37 UTC

[incubator-superset] branch master updated: Make MetricsControl the standard across visualizations (#4914)

This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b427d7  Make MetricsControl the standard across visualizations (#4914)
7b427d7 is described below

commit 7b427d7ee09e0a2c313a047dca5255c2efa4c1b1
Author: Maxime Beauchemin <ma...@gmail.com>
AuthorDate: Tue May 15 12:37:34 2018 -0500

    Make MetricsControl the standard across visualizations (#4914)
    
    * [WiP] make MetricsControl the standard across visualizations
    
    This spreads MetricsControl across visualizations.
    
    * Addressing comments
    
    * Fix deepcopy issue using shallow copy
    
    * Fix tests
---
 superset/assets/src/explore/controls.jsx | 117 ++++++++++---------------------
 superset/data/__init__.py                |   4 +-
 superset/viz.py                          |  70 ++++++++++++------
 3 files changed, 87 insertions(+), 104 deletions(-)

diff --git a/superset/assets/src/explore/controls.jsx b/superset/assets/src/explore/controls.jsx
index bc5dc46..66d563b 100644
--- a/superset/assets/src/explore/controls.jsx
+++ b/superset/assets/src/explore/controls.jsx
@@ -47,7 +47,6 @@ import {
 import * as v from './validators';
 import { colorPrimary, ALL_COLOR_SCHEMES, spectrums } from '../modules/colors';
 import { defaultViewport } from '../modules/geo';
-import MetricOption from '../components/MetricOption';
 import ColumnOption from '../components/ColumnOption';
 import OptionDescription from '../components/OptionDescription';
 import { t } from '../locales';
@@ -116,6 +115,32 @@ const groupByControl = {
   },
 };
 
+const metrics = {
+  type: 'MetricsControl',
+  multi: true,
+  label: t('Metrics'),
+  validators: [v.nonEmpty],
+  default: (c) => {
+    const metric = mainMetric(c.savedMetrics);
+    return metric ? [metric] : null;
+  },
+  mapStateToProps: (state) => {
+    const datasource = state.datasource;
+    return {
+      columns: datasource ? datasource.columns : [],
+      savedMetrics: datasource ? datasource.metrics : [],
+      datasourceType: datasource && datasource.type,
+    };
+  },
+  description: t('One or many metrics to display'),
+};
+const metric = {
+  ...metrics,
+  multi: false,
+  label: t('Metric'),
+  default: props => mainMetric(props.savedMetrics),
+};
+
 const sandboxUrl = (
   'https://github.com/apache/incubator-superset/' +
   'blob/master/superset/assets/src/modules/sandbox.js');
@@ -152,6 +177,11 @@ function jsFunctionControl(label, description, extraDescr = null, height = 100,
 }
 
 export const controls = {
+
+  metrics,
+
+  metric,
+
   datasource: {
     type: 'DatasourceControl',
     label: t('Datasource'),
@@ -169,36 +199,11 @@ export const controls = {
     description: t('The type of visualization to display'),
   },
 
-  metrics: {
-    type: 'MetricsControl',
-    multi: true,
-    label: t('Metrics'),
-    validators: [v.nonEmpty],
-    default: (c) => {
-      const metric = mainMetric(c.savedMetrics);
-      return metric ? [metric] : null;
-    },
-    mapStateToProps: (state) => {
-      const datasource = state.datasource;
-      return {
-        columns: datasource ? datasource.columns : [],
-        savedMetrics: datasource ? datasource.metrics : [],
-        datasourceType: datasource && datasource.type,
-      };
-    },
-    description: t('One or many metrics to display'),
-  },
-
   percent_metrics: {
-    type: 'SelectControl',
+    ...metrics,
     multi: true,
     label: t('Percentage Metrics'),
-    valueKey: 'metric_name',
-    optionRenderer: m => <MetricOption metric={m} showType />,
-    valueRenderer: m => <MetricOption metric={m} />,
-    mapStateToProps: state => ({
-      options: (state.datasource) ? state.datasource.metrics : [],
-    }),
+    validators: [],
     description: t('Metrics for which percentage of total are to be displayed'),
   },
 
@@ -262,33 +267,11 @@ export const controls = {
     renderTrigger: true,
   },
 
-  metric: {
-    type: 'MetricsControl',
-    multi: false,
-    label: t('Metric'),
-    clearable: false,
-    validators: [v.nonEmpty],
-    default: props => mainMetric(props.savedMetrics),
-    mapStateToProps: state => ({
-      columns: state.datasource ? state.datasource.columns : [],
-      savedMetrics: state.datasource ? state.datasource.metrics : [],
-      datasourceType: state.datasource && state.datasource.type,
-    }),
-  },
-
   metric_2: {
-    type: 'SelectControl',
+    ...metric,
     label: t('Right Axis Metric'),
-    default: null,
-    validators: [v.nonEmpty],
     clearable: true,
     description: t('Choose a metric for right axis'),
-    valueKey: 'metric_name',
-    optionRenderer: m => <MetricOption metric={m} showType />,
-    valueRenderer: m => <MetricOption metric={m} />,
-    mapStateToProps: state => ({
-      options: (state.datasource) ? state.datasource.metrics : [],
-    }),
   },
 
   stacked_style: {
@@ -508,13 +491,10 @@ export const controls = {
   },
 
   secondary_metric: {
-    type: 'SelectControl',
+    ...metric,
     label: t('Color Metric'),
     default: null,
     description: t('A metric to use for color'),
-    mapStateToProps: state => ({
-      choices: (state.datasource) ? state.datasource.metrics_combo : [],
-    }),
   },
   select_country: {
     type: 'SelectControl',
@@ -1105,44 +1085,23 @@ export const controls = {
   },
 
   x: {
-    type: 'SelectControl',
+    ...metric,
     label: t('X Axis'),
     description: t('Metric assigned to the [X] axis'),
     default: null,
-    validators: [v.nonEmpty],
-    optionRenderer: m => <MetricOption metric={m} showType />,
-    valueRenderer: m => <MetricOption metric={m} />,
-    valueKey: 'metric_name',
-    mapStateToProps: state => ({
-      options: (state.datasource) ? state.datasource.metrics : [],
-    }),
   },
 
   y: {
-    type: 'SelectControl',
+    ...metric,
     label: t('Y Axis'),
     default: null,
-    validators: [v.nonEmpty],
     description: t('Metric assigned to the [Y] axis'),
-    optionRenderer: m => <MetricOption metric={m} showType />,
-    valueRenderer: m => <MetricOption metric={m} />,
-    valueKey: 'metric_name',
-    mapStateToProps: state => ({
-      options: (state.datasource) ? state.datasource.metrics : [],
-    }),
   },
 
   size: {
-    type: 'SelectControl',
+    ...metric,
     label: t('Bubble Size'),
     default: null,
-    validators: [v.nonEmpty],
-    optionRenderer: m => <MetricOption metric={m} showType />,
-    valueRenderer: m => <MetricOption metric={m} />,
-    valueKey: 'metric_name',
-    mapStateToProps: state => ({
-      options: (state.datasource) ? state.datasource.metrics : [],
-    }),
   },
 
   url: {
diff --git a/superset/data/__init__.py b/superset/data/__init__.py
index 30b588f..d3d7da8 100644
--- a/superset/data/__init__.py
+++ b/superset/data/__init__.py
@@ -1168,10 +1168,10 @@ def load_multiformat_time_series_data():
     obj.fetch_metadata()
     tbl = obj
 
-    print("Creating some slices")
+    print("Creating Heatmap charts")
     for i, col in enumerate(tbl.columns):
         slice_data = {
-            "metric": 'count',
+            "metrics": ['count'],
             "granularity_sqla": col.column_name,
             "granularity_sqla": "day",
             "row_limit": config.get("ROW_LIMIT"),
diff --git a/superset/viz.py b/superset/viz.py
index 1e3fcb5..38e5680 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -43,6 +43,11 @@ from superset.utils import DTTM_ALIAS, JS_MAX_INTEGER, merge_extra_filters
 config = app.config
 stats_logger = config.get('STATS_LOGGER')
 
+METRIC_KEYS = [
+    'metric', 'metrics', 'percent_metrics', 'metric_2', 'secondary_metric',
+    'x', 'y', 'size',
+]
+
 
 class BaseViz(object):
 
@@ -66,13 +71,6 @@ class BaseViz(object):
         self.query = ''
         self.token = self.form_data.get(
             'token', 'token_' + uuid.uuid4().hex[:8])
-        metrics = self.form_data.get('metrics') or []
-        self.metrics = []
-        for metric in metrics:
-            if isinstance(metric, dict):
-                self.metrics.append(metric['label'])
-            else:
-                self.metrics.append(metric)
 
         self.groupby = self.form_data.get('groupby') or []
         self.time_shift = timedelta()
@@ -90,6 +88,29 @@ class BaseViz(object):
         self._any_cached_dttm = None
         self._extra_chart_data = None
 
+        self.process_metrics()
+
+    def process_metrics(self):
+        self.metric_dict = {}
+        fd = self.form_data
+        for mkey in METRIC_KEYS:
+            val = fd.get(mkey)
+            if val:
+                if not isinstance(val, list):
+                    val = [val]
+                for o in val:
+                    self.metric_dict[self.get_metric_label(o)] = o
+
+        # Cast to list needed to return serializable object in py3
+        self.all_metrics = list(self.metric_dict.values())
+        self.metric_labels = list(self.metric_dict.keys())
+
+    def get_metric_label(self, metric):
+        if isinstance(metric, string_types):
+            return metric
+        if isinstance(metric, dict):
+            return metric.get('label')
+
     @staticmethod
     def handle_js_int_overflow(data):
         for d in data.get('records', dict()):
@@ -202,7 +223,7 @@ class BaseViz(object):
         """Building a query object"""
         form_data = self.form_data
         gb = form_data.get('groupby') or []
-        metrics = form_data.get('metrics') or []
+        metrics = self.all_metrics or []
         columns = form_data.get('columns') or []
         groupby = []
         for o in gb + columns:
@@ -346,7 +367,7 @@ class BaseViz(object):
         and replace them with the use-provided inputs to bounds, which
         may we time-relative (as in "5 days ago" or "now").
         """
-        cache_dict = copy.deepcopy(query_obj)
+        cache_dict = copy.copy(query_obj)
 
         for k in ['from_dttm', 'to_dttm']:
             del cache_dict[k]
@@ -520,7 +541,7 @@ class TableViz(BaseViz):
                 'Choose either fields to [Group By] and [Metrics] or '
                 '[Columns], not both'))
 
-        sort_by = fd.get('timeseries_limit_metric')
+        sort_by = fd.get('timeseries_limit_metric') or []
         if fd.get('all_columns'):
             d['columns'] = fd.get('all_columns')
             d['groupby'] = []
@@ -535,7 +556,7 @@ class TableViz(BaseViz):
         if 'percent_metrics' in fd:
             d['metrics'] = d['metrics'] + list(filter(
                 lambda m: m not in d['metrics'],
-                fd['percent_metrics'],
+                fd['percent_metrics'] or [],
             ))
 
         d['is_timeseries'] = self.should_be_timeseries()
@@ -551,7 +572,8 @@ class TableViz(BaseViz):
             del df[DTTM_ALIAS]
 
         # Sum up and compute percentages for all percent metrics
-        percent_metrics = fd.get('percent_metrics', [])
+        percent_metrics = fd.get('percent_metrics') or []
+
         if len(percent_metrics):
             percent_metrics = list(filter(lambda m: m in df, percent_metrics))
             metric_sums = {
@@ -611,10 +633,10 @@ class TimeTableViz(BaseViz):
 
     def get_data(self, df):
         fd = self.form_data
-        values = self.metrics
         columns = None
+        values = self.metric_labels
         if fd.get('groupby'):
-            values = self.metrics[0]
+            values = self.metric_labels[0]
             columns = fd.get('groupby')
         pt = df.pivot_table(
             index=DTTM_ALIAS,
@@ -780,7 +802,7 @@ class CalHeatmapViz(BaseViz):
 
         data = {}
         records = df.to_dict('records')
-        for metric in self.metrics:
+        for metric in self.metric_labels:
             data[metric] = {
                 str(obj[DTTM_ALIAS].value / 10**9): obj.get(metric)
                 for obj in records
@@ -1109,7 +1131,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
             if (
                     isinstance(series_title, (list, tuple)) and
                     len(series_title) > 1 and
-                    len(self.metrics) == 1):
+                    len(self.metric_labels) == 1):
                 # Removing metric from series name if only one metric
                 series_title = series_title[1:]
             if title_suffix:
@@ -1393,10 +1415,11 @@ class DistributionPieViz(NVD3Viz):
     is_timeseries = False
 
     def get_data(self, df):
+        metric = self.metric_labels[0]
         df = df.pivot_table(
             index=self.groupby,
-            values=[self.metrics[0]])
-        df.sort_values(by=self.metrics[0], ascending=False, inplace=True)
+            values=[metric])
+        df.sort_values(by=metric, ascending=False, inplace=True)
         df = df.reset_index()
         df.columns = ['x', 'y']
         return df.to_dict(orient='records')
@@ -1468,14 +1491,15 @@ class DistributionBarViz(DistributionPieViz):
 
     def get_data(self, df):
         fd = self.form_data
+        metrics = self.metric_labels
 
-        row = df.groupby(self.groupby).sum()[self.metrics[0]].copy()
+        row = df.groupby(self.groupby).sum()[metrics[0]].copy()
         row.sort_values(ascending=False, inplace=True)
         columns = fd.get('columns') or []
         pt = df.pivot_table(
             index=self.groupby,
             columns=columns,
-            values=self.metrics)
+            values=metrics)
         if fd.get('contribution'):
             pt = pt.fillna(0)
             pt = pt.T
@@ -1487,7 +1511,7 @@ class DistributionBarViz(DistributionPieViz):
                 continue
             if isinstance(name, string_types):
                 series_title = name
-            elif len(self.metrics) > 1:
+            elif len(metrics) > 1:
                 series_title = ', '.join(name)
             else:
                 l = [str(s) for s in name[1:]]  # noqa: E741
@@ -1664,7 +1688,7 @@ class CountryMapViz(BaseViz):
     def get_data(self, df):
         fd = self.form_data
         cols = [fd.get('entity')]
-        metric = fd.get('metric')
+        metric = self.metric_labels[0]
         cols += [metric]
         ndf = df[cols]
         df = ndf
@@ -1836,7 +1860,7 @@ class HeatmapViz(BaseViz):
         fd = self.form_data
         x = fd.get('all_columns_x')
         y = fd.get('all_columns_y')
-        v = fd.get('metric')
+        v = self.metric_labels[0]
         if x == y:
             df.columns = ['x', 'y', 'v']
         else:

-- 
To stop receiving notification emails like this one, please contact
maximebeauchemin@apache.org.