You are viewing a plain text version of this content. The canonical link for it is the pull request URL given below.
Posted to dev@superset.apache.org by GitBox <gi...@apache.org> on 2018/04/03 04:37:02 UTC
[GitHub] mistercrunch closed pull request #4646: [BugFix] Allowing limit ordering by post-aggregation metrics
mistercrunch closed pull request #4646: [BugFix] Allowing limit ordering by post-aggregation metrics
URL: https://github.com/apache/incubator-superset/pull/4646
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index d514a2f4ee..bd684a5988 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -35,7 +35,7 @@
from superset import conf, db, import_util, security_manager, utils
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
-from superset.exceptions import MetricPermException
+from superset.exceptions import MetricPermException, SupersetException
from superset.models.helpers import (
AuditMixinNullable, ImportMixin, QueryResult, set_perm,
)
@@ -44,6 +44,7 @@
)
DRUID_TZ = conf.get('DRUID_TZ')
+POST_AGG_TYPE = 'postagg'
# Function wrapper because bound methods cannot
@@ -843,7 +844,7 @@ def find_postaggs_for(postagg_names, metrics_dict):
"""Return a list of metrics that are post aggregations"""
postagg_metrics = [
metrics_dict[name] for name in postagg_names
- if metrics_dict[name].metric_type == 'postagg'
+ if metrics_dict[name].metric_type == POST_AGG_TYPE
]
# Remove post aggregations that were found
for postagg in postagg_metrics:
@@ -893,8 +894,8 @@ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dic
missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)
- @staticmethod
- def metrics_and_post_aggs(metrics, metrics_dict):
+ @classmethod
+ def metrics_and_post_aggs(cls, metrics, metrics_dict):
# Separate metrics into those that are aggregations
# and those that are post aggregations
saved_agg_names = set()
@@ -903,7 +904,7 @@ def metrics_and_post_aggs(metrics, metrics_dict):
for metric in metrics:
if utils.is_adhoc_metric(metric):
adhoc_agg_configs.append(metric)
- elif metrics_dict[metric].metric_type != 'postagg':
+ elif metrics_dict[metric].metric_type != POST_AGG_TYPE:
saved_agg_names.add(metric)
else:
postagg_names.append(metric)
@@ -914,9 +915,10 @@ def metrics_and_post_aggs(metrics, metrics_dict):
for postagg_name in postagg_names:
postagg = metrics_dict[postagg_name]
visited_postaggs.add(postagg_name)
- DruidDatasource.resolve_postagg(
+ cls.resolve_postagg(
postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict)
- return list(saved_agg_names), adhoc_agg_configs, post_aggs
+ aggs = cls.get_aggregations(metrics_dict, saved_agg_names, adhoc_agg_configs)
+ return aggs, post_aggs
def values_for_column(self,
column_name,
@@ -982,16 +984,35 @@ def druid_type_from_adhoc_metric(adhoc_metric):
else:
return column_type + aggregate.capitalize()
- def get_aggregations(self, saved_metrics, adhoc_metrics=[]):
+ @staticmethod
+ def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]):
+ """
+ Returns a dictionary of aggregation metric names to aggregation json objects
+
+ :param metrics_dict: dictionary of all the metrics
+ :param saved_metrics: list of saved metric names
+ :param adhoc_metrics: list of adhoc metric names
+ :raise SupersetException: if one or more metric names are not aggregations
+ """
aggregations = OrderedDict()
- for m in self.metrics:
- if m.metric_name in saved_metrics:
- aggregations[m.metric_name] = m.json_obj
+ invalid_metric_names = []
+ for metric_name in saved_metrics:
+ if metric_name in metrics_dict:
+ metric = metrics_dict[metric_name]
+ if metric.metric_type == POST_AGG_TYPE:
+ invalid_metric_names.append(metric_name)
+ else:
+ aggregations[metric_name] = metric.json_obj
+ else:
+ invalid_metric_names.append(metric_name)
+ if len(invalid_metric_names) > 0:
+ raise SupersetException(
+ _('Metric(s) {} must be aggregations.').format(invalid_metric_names))
for adhoc_metric in adhoc_metrics:
aggregations[adhoc_metric['label']] = {
'fieldName': adhoc_metric['column']['column_name'],
'fieldNames': [adhoc_metric['column']['column_name']],
- 'type': self.druid_type_from_adhoc_metric(adhoc_metric),
+ 'type': DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric),
'name': adhoc_metric['label'],
}
return aggregations
@@ -1087,11 +1108,10 @@ def run_query( # noqa / druid
metrics_dict = {m.metric_name: m for m in self.metrics}
columns_dict = {c.column_name: c for c in self.columns}
- saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+ aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics,
metrics_dict)
- aggregations = self.get_aggregations(saved_metrics, adhoc_metrics)
self.check_restricted_metrics(aggregations)
# the dimensions list with dimensionSpecs expanded
@@ -1143,7 +1163,15 @@ def run_query( # noqa / druid
pre_qry = deepcopy(qry)
if timeseries_limit_metric:
order_by = timeseries_limit_metric
- pre_qry['aggregations'] = self.get_aggregations([timeseries_limit_metric])
+ aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
+ [timeseries_limit_metric],
+ metrics_dict)
+ if phase == 1:
+ pre_qry['aggregations'].update(aggs_dict)
+ pre_qry['post_aggregations'].update(post_aggs_dict)
+ else:
+ pre_qry['aggregations'] = aggs_dict
+ pre_qry['post_aggregations'] = post_aggs_dict
else:
order_by = list(qry['aggregations'].keys())[0]
# Limit on the number of timeseries, doing a two-phases query
@@ -1193,6 +1221,15 @@ def run_query( # noqa / druid
if timeseries_limit_metric:
order_by = timeseries_limit_metric
+ aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
+ [timeseries_limit_metric],
+ metrics_dict)
+ if phase == 1:
+ pre_qry['aggregations'].update(aggs_dict)
+ pre_qry['post_aggregations'].update(post_aggs_dict)
+ else:
+ pre_qry['aggregations'] = aggs_dict
+ pre_qry['post_aggregations'] = post_aggs_dict
# Limit on the number of timeseries, doing a two-phases query
pre_qry['granularity'] = 'all'
diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py
index 22c1f38dc9..c47849433c 100644
--- a/tests/druid_func_tests.py
+++ b/tests/druid_func_tests.py
@@ -14,6 +14,7 @@
from superset.connectors.druid.models import (
DruidColumn, DruidDatasource, DruidMetric,
)
+from superset.exceptions import SupersetException
def mock_metric(metric_name, is_postagg=False):
@@ -157,9 +158,9 @@ def test_run_query_no_groupby(self):
col1 = DruidColumn(column_name='col1')
col2 = DruidColumn(column_name='col2')
ds.columns = [col1, col2]
- all_metrics = []
+ aggs = []
post_aggs = ['some_agg']
- ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
+ ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
groupby = []
metrics = ['metric1']
ds.get_having_filters = Mock(return_value=[])
@@ -242,9 +243,9 @@ def test_run_query_single_groupby(self):
col1 = DruidColumn(column_name='col1')
col2 = DruidColumn(column_name='col2')
ds.columns = [col1, col2]
- all_metrics = ['metric1']
+ aggs = ['metric1']
post_aggs = ['some_agg']
- ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
+ ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
groupby = ['col1']
metrics = ['metric1']
ds.get_having_filters = Mock(return_value=[])
@@ -316,9 +317,9 @@ def test_run_query_multiple_groupby(self):
col1 = DruidColumn(column_name='col1')
col2 = DruidColumn(column_name='col2')
ds.columns = [col1, col2]
- all_metrics = []
+ aggs = []
post_aggs = ['some_agg']
- ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
+ ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
groupby = ['col1', 'col2']
metrics = ['metric1']
ds.get_having_filters = Mock(return_value=[])
@@ -512,10 +513,10 @@ def depends_on(index, fields):
depends_on('I', ['H', 'K'])
depends_on('J', 'K')
depends_on('K', ['m8', 'm9'])
- all_metrics, saved_metrics, postaggs = DruidDatasource.metrics_and_post_aggs(
+ aggs, postaggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
- expected_metrics = set(all_metrics)
- self.assertEqual(9, len(all_metrics))
+ expected_metrics = set(aggs.keys())
+ self.assertEqual(9, len(aggs))
for i in range(1, 10):
expected_metrics.remove('m' + str(i))
self.assertEqual(0, len(expected_metrics))
@@ -593,45 +594,40 @@ def test_metrics_and_post_aggs(self):
}
metrics = ['some_sum']
- saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+ saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
- assert saved_metrics == ['some_sum']
- assert adhoc_metrics == []
+ assert set(saved_metrics.keys()) == {'some_sum'}
assert post_aggs == {}
metrics = [adhoc_metric]
- saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+ saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
- assert saved_metrics == []
- assert adhoc_metrics == [adhoc_metric]
+ assert set(saved_metrics.keys()) == set([adhoc_metric['label']])
assert post_aggs == {}
metrics = ['some_sum', adhoc_metric]
- saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+ saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
- assert saved_metrics == ['some_sum']
- assert adhoc_metrics == [adhoc_metric]
+ assert set(saved_metrics.keys()) == {'some_sum', adhoc_metric['label']}
assert post_aggs == {}
metrics = ['quantile_p95']
- saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+ saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
result_postaggs = set(['quantile_p95'])
- assert saved_metrics == ['a_histogram']
- assert adhoc_metrics == []
+ assert set(saved_metrics.keys()) == {'a_histogram'}
assert set(post_aggs.keys()) == result_postaggs
metrics = ['aCustomPostAgg']
- saved_metrics, adhoc_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+ saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics, metrics_dict)
result_postaggs = set(['aCustomPostAgg'])
- assert saved_metrics == ['aCustomMetric']
- assert adhoc_metrics == []
+ assert set(saved_metrics.keys()) == {'aCustomMetric'}
assert set(post_aggs.keys()) == result_postaggs
def test_druid_type_from_adhoc_metric(self):
@@ -663,3 +659,157 @@ def test_druid_type_from_adhoc_metric(self):
'label': 'My Adhoc Metric',
})
assert(druid_type == 'cardinality')
+
+ def test_run_query_order_by_metrics(self):
+ client = Mock()
+ client.query_builder.last_query.query_dict = {'mock': 0}
+ from_dttm = Mock()
+ to_dttm = Mock()
+ ds = DruidDatasource(datasource_name='datasource')
+ ds.get_having_filters = Mock(return_value=[])
+ dim1 = DruidColumn(column_name='dim1')
+ dim2 = DruidColumn(column_name='dim2')
+ metrics_dict = {
+ 'count1': DruidMetric(
+ metric_name='count1',
+ metric_type='count',
+ json=json.dumps({'type': 'count', 'name': 'count1'}),
+ ),
+ 'sum1': DruidMetric(
+ metric_name='sum1',
+ metric_type='doubleSum',
+ json=json.dumps({'type': 'doubleSum', 'name': 'sum1'}),
+ ),
+ 'sum2': DruidMetric(
+ metric_name='sum2',
+ metric_type='doubleSum',
+ json=json.dumps({'type': 'doubleSum', 'name': 'sum2'}),
+ ),
+ 'div1': DruidMetric(
+ metric_name='div1',
+ metric_type='postagg',
+ json=json.dumps({
+ 'fn': '/',
+ 'type': 'arithmetic',
+ 'name': 'div1',
+ 'fields': [
+ {
+ 'fieldName': 'sum1',
+ 'type': 'fieldAccess',
+ },
+ {
+ 'fieldName': 'sum2',
+ 'type': 'fieldAccess',
+ },
+ ],
+ }),
+ ),
+ }
+ ds.columns = [dim1, dim2]
+ ds.metrics = list(metrics_dict.values())
+
+ groupby = ['dim1']
+ metrics = ['count1']
+ granularity = 'all'
+ # get the counts of the top 5 'dim1's, order by 'sum1'
+ ds.run_query(
+ groupby, metrics, granularity, from_dttm, to_dttm,
+ timeseries_limit=5, timeseries_limit_metric='sum1',
+ client=client, order_desc=True, filter=[],
+ )
+ qry_obj = client.topn.call_args_list[0][1]
+ self.assertEqual('dim1', qry_obj['dimension'])
+ self.assertEqual('sum1', qry_obj['metric'])
+ aggregations = qry_obj['aggregations']
+ post_aggregations = qry_obj['post_aggregations']
+ self.assertEqual({'count1', 'sum1'}, set(aggregations.keys()))
+ self.assertEqual(set(), set(post_aggregations.keys()))
+
+ # get the counts of the top 5 'dim1's, order by 'div1'
+ ds.run_query(
+ groupby, metrics, granularity, from_dttm, to_dttm,
+ timeseries_limit=5, timeseries_limit_metric='div1',
+ client=client, order_desc=True, filter=[],
+ )
+ qry_obj = client.topn.call_args_list[1][1]
+ self.assertEqual('dim1', qry_obj['dimension'])
+ self.assertEqual('div1', qry_obj['metric'])
+ aggregations = qry_obj['aggregations']
+ post_aggregations = qry_obj['post_aggregations']
+ self.assertEqual({'count1', 'sum1', 'sum2'}, set(aggregations.keys()))
+ self.assertEqual({'div1'}, set(post_aggregations.keys()))
+
+ groupby = ['dim1', 'dim2']
+ # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1'
+ ds.run_query(
+ groupby, metrics, granularity, from_dttm, to_dttm,
+ timeseries_limit=5, timeseries_limit_metric='sum1',
+ client=client, order_desc=True, filter=[],
+ )
+ qry_obj = client.groupby.call_args_list[0][1]
+ self.assertEqual({'dim1', 'dim2'}, set(qry_obj['dimensions']))
+ self.assertEqual('sum1', qry_obj['limit_spec']['columns'][0]['dimension'])
+ aggregations = qry_obj['aggregations']
+ post_aggregations = qry_obj['post_aggregations']
+ self.assertEqual({'count1', 'sum1'}, set(aggregations.keys()))
+ self.assertEqual(set(), set(post_aggregations.keys()))
+
+ # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1'
+ ds.run_query(
+ groupby, metrics, granularity, from_dttm, to_dttm,
+ timeseries_limit=5, timeseries_limit_metric='div1',
+ client=client, order_desc=True, filter=[],
+ )
+ qry_obj = client.groupby.call_args_list[1][1]
+ self.assertEqual({'dim1', 'dim2'}, set(qry_obj['dimensions']))
+ self.assertEqual('div1', qry_obj['limit_spec']['columns'][0]['dimension'])
+ aggregations = qry_obj['aggregations']
+ post_aggregations = qry_obj['post_aggregations']
+ self.assertEqual({'count1', 'sum1', 'sum2'}, set(aggregations.keys()))
+ self.assertEqual({'div1'}, set(post_aggregations.keys()))
+
+ def test_get_aggregations(self):
+ ds = DruidDatasource(datasource_name='datasource')
+ metrics_dict = {
+ 'sum1': DruidMetric(
+ metric_name='sum1',
+ metric_type='doubleSum',
+ json=json.dumps({'type': 'doubleSum', 'name': 'sum1'}),
+ ),
+ 'sum2': DruidMetric(
+ metric_name='sum2',
+ metric_type='doubleSum',
+ json=json.dumps({'type': 'doubleSum', 'name': 'sum2'}),
+ ),
+ 'div1': DruidMetric(
+ metric_name='div1',
+ metric_type='postagg',
+ json=json.dumps({
+ 'fn': '/',
+ 'type': 'arithmetic',
+ 'name': 'div1',
+ 'fields': [
+ {
+ 'fieldName': 'sum1',
+ 'type': 'fieldAccess',
+ },
+ {
+ 'fieldName': 'sum2',
+ 'type': 'fieldAccess',
+ },
+ ],
+ }),
+ ),
+ }
+ metric_names = ['sum1', 'sum2']
+ aggs = ds.get_aggregations(metrics_dict, metric_names)
+ expected_agg = {name: metrics_dict[name].json_obj for name in metric_names}
+ self.assertEqual(expected_agg, aggs)
+
+ metric_names = ['sum1', 'col1']
+ self.assertRaises(
+ SupersetException, ds.get_aggregations, metrics_dict, metric_names)
+
+ metric_names = ['sum1', 'div1']
+ self.assertRaises(
+ SupersetException, ds.get_aggregations, metrics_dict, metric_names)
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services