You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@superset.apache.org by GitBox <gi...@apache.org> on 2018/04/03 04:37:02 UTC

[GitHub] mistercrunch closed pull request #4646: [BugFix] Allowing limit ordering by post-aggregation metrics

mistercrunch closed pull request #4646: [BugFix] Allowing limit ordering by post-aggregation metrics
URL: https://github.com/apache/incubator-superset/pull/4646
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index d514a2f4ee..bd684a5988 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -35,7 +35,7 @@
 
 from superset import conf, db, import_util, security_manager, utils
 from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
-from superset.exceptions import MetricPermException
+from superset.exceptions import MetricPermException, SupersetException
 from superset.models.helpers import (
     AuditMixinNullable, ImportMixin, QueryResult, set_perm,
 )
@@ -44,6 +44,7 @@
 )
 
 DRUID_TZ = conf.get('DRUID_TZ')
+POST_AGG_TYPE = 'postagg'
 
 
 # Function wrapper because bound methods cannot
@@ -843,7 +844,7 @@ def find_postaggs_for(postagg_names, metrics_dict):
         """Return a list of metrics that are post aggregations"""
         postagg_metrics = [
             metrics_dict[name] for name in postagg_names
-            if metrics_dict[name].metric_type == 'postagg'
+            if metrics_dict[name].metric_type == POST_AGG_TYPE
         ]
         # Remove post aggregations that were found
         for postagg in postagg_metrics:
@@ -893,8 +894,8 @@ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dic
                     missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
         post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)
 
-    @staticmethod
-    def metrics_and_post_aggs(metrics, metrics_dict):
+    @classmethod
+    def metrics_and_post_aggs(cls, metrics, metrics_dict):
         # Separate metrics into those that are aggregations
         # and those that are post aggregations
         saved_agg_names = set()
@@ -903,7 +904,7 @@ def metrics_and_post_aggs(metrics, metrics_dict):
         for metric in metrics:
             if utils.is_adhoc_metric(metric):
                 adhoc_agg_configs.append(metric)
-            elif metrics_dict[metric].metric_type != 'postagg':
+            elif metrics_dict[metric].metric_type != POST_AGG_TYPE:
                 saved_agg_names.add(metric)
             else:
                 postagg_names.append(metric)
@@ -914,9 +915,10 @@ def metrics_and_post_aggs(metrics, metrics_dict):
         for postagg_name in postagg_names:
             postagg = metrics_dict[postagg_name]
             visited_postaggs.add(postagg_name)
-            DruidDatasource.resolve_postagg(
+            cls.resolve_postagg(
                 postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict)
-        return list(saved_agg_names), adhoc_agg_configs, post_aggs
+        aggs = cls.get_aggregations(metrics_dict, saved_agg_names, adhoc_agg_configs)
+        return aggs, post_aggs
 
     def values_for_column(self,
                           column_name,
@@ -982,16 +984,35 @@ def druid_type_from_adhoc_metric(adhoc_metric):
         else:
             return column_type + aggregate.capitalize()
 
-    def get_aggregations(self, saved_metrics, adhoc_metrics=[]):
+    @staticmethod
+    def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]):
+        """
+            Returns a dictionary of aggregation metric names to aggregation json objects
+
+            :param metrics_dict: dictionary of all the metrics
+            :param saved_metrics: list of saved metric names
+            :param adhoc_metrics: list of adhoc metric names
+            :raise SupersetException: if one or more metric names are not aggregations
+        """
         aggregations = OrderedDict()
-        for m in self.metrics:
-            if m.metric_name in saved_metrics:
-                aggregations[m.metric_name] = m.json_obj
+        invalid_metric_names = []
+        for metric_name in saved_metrics:
+            if metric_name in metrics_dict:
+                metric = metrics_dict[metric_name]
+                if metric.metric_type == POST_AGG_TYPE:
+                    invalid_metric_names.append(metric_name)
+                else:
+                    aggregations[metric_name] = metric.json_obj
+            else:
+                invalid_metric_names.append(metric_name)
+        if len(invalid_metric_names) > 0:
+            raise SupersetException(
+                _('Metric(s) {} must be aggregations.').format(invalid_metric_names))
         for adhoc_metric in adhoc_metrics:
             aggregations[adhoc_metric['label']] = {
                 'fieldName': adhoc_metric['column']['column_name'],
                 'fieldNames': [adhoc_metric['column']['column_name']],
-                'type': self.druid_type_from_adhoc_metric(adhoc_metric),
+                'type': DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric),
                 'name': adhoc_metric['label'],
             }
         return aggregations
@@ -1087,11 +1108,10 @@ def run_query(  # noqa / druid
         metrics_dict = {m.metric_name: m for m in self.metrics}
         columns_dict = {c.column_name: c for c in self.columns}
 
-        saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+        aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs(
             metrics,
             metrics_dict)
 
-        aggregations = self.get_aggregations(saved_metrics, adhoc_metrics)
         self.check_restricted_metrics(aggregations)
 
         # the dimensions list with dimensionSpecs expanded
@@ -1143,7 +1163,15 @@ def run_query(  # noqa / druid
             pre_qry = deepcopy(qry)
             if timeseries_limit_metric:
                 order_by = timeseries_limit_metric
-                pre_qry['aggregations'] = self.get_aggregations([timeseries_limit_metric])
+                aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
+                    [timeseries_limit_metric],
+                    metrics_dict)
+                if phase == 1:
+                    pre_qry['aggregations'].update(aggs_dict)
+                    pre_qry['post_aggregations'].update(post_aggs_dict)
+                else:
+                    pre_qry['aggregations'] = aggs_dict
+                    pre_qry['post_aggregations'] = post_aggs_dict
             else:
                 order_by = list(qry['aggregations'].keys())[0]
             # Limit on the number of timeseries, doing a two-phases query
@@ -1193,6 +1221,15 @@ def run_query(  # noqa / druid
 
                 if timeseries_limit_metric:
                     order_by = timeseries_limit_metric
+                    aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
+                        [timeseries_limit_metric],
+                        metrics_dict)
+                    if phase == 1:
+                        pre_qry['aggregations'].update(aggs_dict)
+                        pre_qry['post_aggregations'].update(post_aggs_dict)
+                    else:
+                        pre_qry['aggregations'] = aggs_dict
+                        pre_qry['post_aggregations'] = post_aggs_dict
 
                 # Limit on the number of timeseries, doing a two-phases query
                 pre_qry['granularity'] = 'all'
diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py
index 22c1f38dc9..c47849433c 100644
--- a/tests/druid_func_tests.py
+++ b/tests/druid_func_tests.py
@@ -14,6 +14,7 @@
 from superset.connectors.druid.models import (
     DruidColumn, DruidDatasource, DruidMetric,
 )
+from superset.exceptions import SupersetException
 
 
 def mock_metric(metric_name, is_postagg=False):
@@ -157,9 +158,9 @@ def test_run_query_no_groupby(self):
         col1 = DruidColumn(column_name='col1')
         col2 = DruidColumn(column_name='col2')
         ds.columns = [col1, col2]
-        all_metrics = []
+        aggs = []
         post_aggs = ['some_agg']
-        ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
+        ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
         groupby = []
         metrics = ['metric1']
         ds.get_having_filters = Mock(return_value=[])
@@ -242,9 +243,9 @@ def test_run_query_single_groupby(self):
         col1 = DruidColumn(column_name='col1')
         col2 = DruidColumn(column_name='col2')
         ds.columns = [col1, col2]
-        all_metrics = ['metric1']
+        aggs = ['metric1']
         post_aggs = ['some_agg']
-        ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
+        ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
         groupby = ['col1']
         metrics = ['metric1']
         ds.get_having_filters = Mock(return_value=[])
@@ -316,9 +317,9 @@ def test_run_query_multiple_groupby(self):
         col1 = DruidColumn(column_name='col1')
         col2 = DruidColumn(column_name='col2')
         ds.columns = [col1, col2]
-        all_metrics = []
+        aggs = []
         post_aggs = ['some_agg']
-        ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
+        ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
         groupby = ['col1', 'col2']
         metrics = ['metric1']
         ds.get_having_filters = Mock(return_value=[])
@@ -512,10 +513,10 @@ def depends_on(index, fields):
         depends_on('I', ['H', 'K'])
         depends_on('J', 'K')
         depends_on('K', ['m8', 'm9'])
-        all_metrics, saved_metrics, postaggs = DruidDatasource.metrics_and_post_aggs(
+        aggs, postaggs = DruidDatasource.metrics_and_post_aggs(
             metrics, metrics_dict)
-        expected_metrics = set(all_metrics)
-        self.assertEqual(9, len(all_metrics))
+        expected_metrics = set(aggs.keys())
+        self.assertEqual(9, len(aggs))
         for i in range(1, 10):
             expected_metrics.remove('m' + str(i))
         self.assertEqual(0, len(expected_metrics))
@@ -593,45 +594,40 @@ def test_metrics_and_post_aggs(self):
         }
 
         metrics = ['some_sum']
-        saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
             metrics, metrics_dict)
 
-        assert saved_metrics == ['some_sum']
-        assert adhoc_metrics == []
+        assert set(saved_metrics.keys()) == {'some_sum'}
         assert post_aggs == {}
 
         metrics = [adhoc_metric]
-        saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
             metrics, metrics_dict)
 
-        assert saved_metrics == []
-        assert adhoc_metrics == [adhoc_metric]
+        assert set(saved_metrics.keys()) == set([adhoc_metric['label']])
         assert post_aggs == {}
 
         metrics = ['some_sum', adhoc_metric]
-        saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
             metrics, metrics_dict)
 
-        assert saved_metrics == ['some_sum']
-        assert adhoc_metrics == [adhoc_metric]
+        assert set(saved_metrics.keys()) == {'some_sum', adhoc_metric['label']}
         assert post_aggs == {}
 
         metrics = ['quantile_p95']
-        saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
             metrics, metrics_dict)
 
         result_postaggs = set(['quantile_p95'])
-        assert saved_metrics == ['a_histogram']
-        assert adhoc_metrics == []
+        assert set(saved_metrics.keys()) == {'a_histogram'}
         assert set(post_aggs.keys()) == result_postaggs
 
         metrics = ['aCustomPostAgg']
-        saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
             metrics, metrics_dict)
 
         result_postaggs = set(['aCustomPostAgg'])
-        assert saved_metrics == ['aCustomMetric']
-        assert adhoc_metrics == []
+        assert set(saved_metrics.keys()) == {'aCustomMetric'}
         assert set(post_aggs.keys()) == result_postaggs
 
     def test_druid_type_from_adhoc_metric(self):
@@ -663,3 +659,157 @@ def test_druid_type_from_adhoc_metric(self):
             'label': 'My Adhoc Metric',
         })
         assert(druid_type == 'cardinality')
+
+    def test_run_query_order_by_metrics(self):
+        client = Mock()
+        client.query_builder.last_query.query_dict = {'mock': 0}
+        from_dttm = Mock()
+        to_dttm = Mock()
+        ds = DruidDatasource(datasource_name='datasource')
+        ds.get_having_filters = Mock(return_value=[])
+        dim1 = DruidColumn(column_name='dim1')
+        dim2 = DruidColumn(column_name='dim2')
+        metrics_dict = {
+            'count1': DruidMetric(
+                metric_name='count1',
+                metric_type='count',
+                json=json.dumps({'type': 'count', 'name': 'count1'}),
+            ),
+            'sum1': DruidMetric(
+                metric_name='sum1',
+                metric_type='doubleSum',
+                json=json.dumps({'type': 'doubleSum', 'name': 'sum1'}),
+            ),
+            'sum2': DruidMetric(
+                metric_name='sum2',
+                metric_type='doubleSum',
+                json=json.dumps({'type': 'doubleSum', 'name': 'sum2'}),
+            ),
+            'div1': DruidMetric(
+                metric_name='div1',
+                metric_type='postagg',
+                json=json.dumps({
+                    'fn': '/',
+                    'type': 'arithmetic',
+                    'name': 'div1',
+                    'fields': [
+                        {
+                            'fieldName': 'sum1',
+                            'type': 'fieldAccess',
+                        },
+                        {
+                            'fieldName': 'sum2',
+                            'type': 'fieldAccess',
+                        },
+                    ],
+                }),
+            ),
+        }
+        ds.columns = [dim1, dim2]
+        ds.metrics = list(metrics_dict.values())
+
+        groupby = ['dim1']
+        metrics = ['count1']
+        granularity = 'all'
+        # get the counts of the top 5 'dim1's, order by 'sum1'
+        ds.run_query(
+            groupby, metrics, granularity, from_dttm, to_dttm,
+            timeseries_limit=5, timeseries_limit_metric='sum1',
+            client=client, order_desc=True, filter=[],
+        )
+        qry_obj = client.topn.call_args_list[0][1]
+        self.assertEqual('dim1', qry_obj['dimension'])
+        self.assertEqual('sum1', qry_obj['metric'])
+        aggregations = qry_obj['aggregations']
+        post_aggregations = qry_obj['post_aggregations']
+        self.assertEqual({'count1', 'sum1'}, set(aggregations.keys()))
+        self.assertEqual(set(), set(post_aggregations.keys()))
+
+        # get the counts of the top 5 'dim1's, order by 'div1'
+        ds.run_query(
+            groupby, metrics, granularity, from_dttm, to_dttm,
+            timeseries_limit=5, timeseries_limit_metric='div1',
+            client=client, order_desc=True, filter=[],
+        )
+        qry_obj = client.topn.call_args_list[1][1]
+        self.assertEqual('dim1', qry_obj['dimension'])
+        self.assertEqual('div1', qry_obj['metric'])
+        aggregations = qry_obj['aggregations']
+        post_aggregations = qry_obj['post_aggregations']
+        self.assertEqual({'count1', 'sum1', 'sum2'}, set(aggregations.keys()))
+        self.assertEqual({'div1'}, set(post_aggregations.keys()))
+
+        groupby = ['dim1', 'dim2']
+        # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1'
+        ds.run_query(
+            groupby, metrics, granularity, from_dttm, to_dttm,
+            timeseries_limit=5, timeseries_limit_metric='sum1',
+            client=client, order_desc=True, filter=[],
+        )
+        qry_obj = client.groupby.call_args_list[0][1]
+        self.assertEqual({'dim1', 'dim2'}, set(qry_obj['dimensions']))
+        self.assertEqual('sum1', qry_obj['limit_spec']['columns'][0]['dimension'])
+        aggregations = qry_obj['aggregations']
+        post_aggregations = qry_obj['post_aggregations']
+        self.assertEqual({'count1', 'sum1'}, set(aggregations.keys()))
+        self.assertEqual(set(), set(post_aggregations.keys()))
+
+        # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1'
+        ds.run_query(
+            groupby, metrics, granularity, from_dttm, to_dttm,
+            timeseries_limit=5, timeseries_limit_metric='div1',
+            client=client, order_desc=True, filter=[],
+        )
+        qry_obj = client.groupby.call_args_list[1][1]
+        self.assertEqual({'dim1', 'dim2'}, set(qry_obj['dimensions']))
+        self.assertEqual('div1', qry_obj['limit_spec']['columns'][0]['dimension'])
+        aggregations = qry_obj['aggregations']
+        post_aggregations = qry_obj['post_aggregations']
+        self.assertEqual({'count1', 'sum1', 'sum2'}, set(aggregations.keys()))
+        self.assertEqual({'div1'}, set(post_aggregations.keys()))
+
+    def test_get_aggregations(self):
+        ds = DruidDatasource(datasource_name='datasource')
+        metrics_dict = {
+            'sum1': DruidMetric(
+                metric_name='sum1',
+                metric_type='doubleSum',
+                json=json.dumps({'type': 'doubleSum', 'name': 'sum1'}),
+            ),
+            'sum2': DruidMetric(
+                metric_name='sum2',
+                metric_type='doubleSum',
+                json=json.dumps({'type': 'doubleSum', 'name': 'sum2'}),
+            ),
+            'div1': DruidMetric(
+                metric_name='div1',
+                metric_type='postagg',
+                json=json.dumps({
+                    'fn': '/',
+                    'type': 'arithmetic',
+                    'name': 'div1',
+                    'fields': [
+                        {
+                            'fieldName': 'sum1',
+                            'type': 'fieldAccess',
+                        },
+                        {
+                            'fieldName': 'sum2',
+                            'type': 'fieldAccess',
+                        },
+                    ],
+                }),
+            ),
+        }
+        metric_names = ['sum1', 'sum2']
+        aggs = ds.get_aggregations(metrics_dict, metric_names)
+        expected_agg = {name: metrics_dict[name].json_obj for name in metric_names}
+        self.assertEqual(expected_agg, aggs)
+
+        metric_names = ['sum1', 'col1']
+        self.assertRaises(
+            SupersetException, ds.get_aggregations, metrics_dict, metric_names)
+
+        metric_names = ['sum1', 'div1']
+        self.assertRaises(
+            SupersetException, ds.get_aggregations, metrics_dict, metric_names)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services