You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this text copy.
Posted to commits@superset.apache.org by ma...@apache.org on 2018/08/29 04:04:09 UTC

[incubator-superset] branch master updated: [bugfix] 'DruidCluster' object has no attribute 'db_engine_spec' (#5765)

This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 135539c  [bugfix] 'DruidCluster' object has no attribute 'db_engine_spec' (#5765)
135539c is described below

commit 135539c109a90a5f13af8da0cc9842d3753b3bfe
Author: Maxime Beauchemin <ma...@gmail.com>
AuthorDate: Tue Aug 28 21:04:06 2018 -0700

    [bugfix] 'DruidCluster' object has no attribute 'db_engine_spec' (#5765)
    
    * [bugfix] 'DruidCluster' object has no attribute 'db_engine_spec'
    
    * Fix tests
---
 superset/viz.py     |   9 ++++-
 tests/base_tests.py |  19 ++++++++++
 tests/viz_tests.py  | 105 +++++++++++++++++++++++-----------------------------
 3 files changed, 73 insertions(+), 60 deletions(-)

diff --git a/superset/viz.py b/superset/viz.py
index 601411c..5701beb 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -120,9 +120,14 @@ class BaseViz(object):
     def get_metric_label(self, metric):
         if isinstance(metric, string_types):
             return metric
+
         if isinstance(metric, dict):
-            return self.datasource.database.db_engine_spec.mutate_expression_label(
-                metric.get('label'))
+            metric = metric.get('label')
+
+        if self.datasource.type == 'table':
+            db_engine_spec = self.datasource.database.db_engine_spec
+            metric = db_engine_spec.mutate_expression_label(metric)
+        return metric
 
     @staticmethod
     def handle_js_int_overflow(data):
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 35ac335..782cedd 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -11,6 +11,8 @@ import os
 import unittest
 
 from flask_appbuilder.security.sqla import models as ab_models
+from mock import Mock
+import pandas as pd
 
 from superset import app, cli, db, security_manager, utils
 from superset.connectors.druid.models import DruidCluster, DruidDatasource
@@ -147,6 +149,23 @@ class SupersetTestCase(unittest.TestCase):
         return db.session.query(DruidDatasource).filter_by(
             datasource_name=name).first()
 
+    def get_datasource_mock(self):
+        datasource = Mock()
+        results = Mock()
+        results.query = Mock()
+        results.status = Mock()
+        results.error_message = None
+        results.df = pd.DataFrame()
+        datasource.type = 'table'
+        datasource.query = Mock(return_value=results)
+        mock_dttm_col = Mock()
+        datasource.get_col = Mock(return_value=mock_dttm_col)
+        datasource.query = Mock(return_value=results)
+        datasource.database = Mock()
+        datasource.database.db_engine_spec = Mock()
+        datasource.database.db_engine_spec.mutate_expression_label = lambda x: x
+        return datasource
+
     def get_resp(
             self, url, data=None, follow_redirects=True, raise_on_error=True):
         """Shortcut to get the parsed results while following redirects"""
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index ccca026..c1bc15e 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -5,7 +5,6 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from datetime import datetime
-import unittest
 import uuid
 
 from mock import Mock, patch
@@ -15,10 +14,11 @@ from superset import app
 from superset.exceptions import SpatialException
 from superset.utils import DTTM_ALIAS
 import superset.viz as viz
+from .base_tests import SupersetTestCase
 from .utils import load_fixture
 
 
-class BaseVizTestCase(unittest.TestCase):
+class BaseVizTestCase(SupersetTestCase):
 
     def test_constructor_exception_no_datasource(self):
         form_data = {}
@@ -31,7 +31,7 @@ class BaseVizTestCase(unittest.TestCase):
             'viz_type': 'table',
             'token': '12345',
         }
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         test_viz = viz.BaseViz(datasource, form_data)
         self.assertEqual(
             test_viz.default_fillna,
@@ -39,16 +39,9 @@ class BaseVizTestCase(unittest.TestCase):
         )
 
     def test_get_df_returns_empty_df(self):
-        datasource = Mock()
-        datasource.type = 'table'
         form_data = {'dummy': 123}
         query_obj = {'granularity': 'day'}
-        results = Mock()
-        results.query = Mock()
-        results.status = Mock()
-        results.error_message = None
-        results.df = pd.DataFrame()
-        datasource.query = Mock(return_value=results)
+        datasource = self.get_datasource_mock()
         test_viz = viz.BaseViz(datasource, form_data)
         result = test_viz.get_df(query_obj)
         self.assertEqual(type(result), pd.DataFrame)
@@ -66,14 +59,20 @@ class BaseVizTestCase(unittest.TestCase):
         datasource.query = Mock(return_value=results)
         mock_dttm_col = Mock()
         datasource.get_col = Mock(return_value=mock_dttm_col)
+
         test_viz = viz.BaseViz(datasource, form_data)
         test_viz.df_metrics_to_num = Mock()
         test_viz.get_fillna_for_columns = Mock(return_value=0)
 
         results.df = pd.DataFrame(data={DTTM_ALIAS: ['1960-01-01 05:00:00']})
         datasource.offset = 0
+        mock_dttm_col = Mock()
+        datasource.get_col = Mock(return_value=mock_dttm_col)
         mock_dttm_col.python_date_format = 'epoch_ms'
         result = test_viz.get_df(query_obj)
+        print(result)
+        import logging
+        logging.info(result)
         pd.testing.assert_series_equal(
             result[DTTM_ALIAS],
             pd.Series([datetime(1960, 1, 1, 5, 0)], name=DTTM_ALIAS),
@@ -103,38 +102,28 @@ class BaseVizTestCase(unittest.TestCase):
         )
 
     def test_cache_timeout(self):
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         datasource.cache_timeout = 0
         test_viz = viz.BaseViz(datasource, form_data={})
         self.assertEqual(0, test_viz.cache_timeout)
+
         datasource.cache_timeout = 156
         test_viz = viz.BaseViz(datasource, form_data={})
         self.assertEqual(156, test_viz.cache_timeout)
+
         datasource.cache_timeout = None
-        datasource.database = Mock()
         datasource.database.cache_timeout = 0
         self.assertEqual(0, test_viz.cache_timeout)
+
         datasource.database.cache_timeout = 1666
         self.assertEqual(1666, test_viz.cache_timeout)
+
         datasource.database.cache_timeout = None
         test_viz = viz.BaseViz(datasource, form_data={})
         self.assertEqual(app.config['CACHE_DEFAULT_TIMEOUT'], test_viz.cache_timeout)
 
 
-class TableVizTestCase(unittest.TestCase):
-
-    class DBEngineSpecMock:
-        @staticmethod
-        def mutate_expression_label(label):
-            return label
-
-    class DatabaseMock:
-        def __init__(self):
-            self.db_engine_spec = TableVizTestCase.DBEngineSpecMock()
-
-    class DatasourceMock:
-        def __init__(self):
-            self.database = TableVizTestCase.DatabaseMock()
+class TableVizTestCase(SupersetTestCase):
 
     def test_get_data_applies_percentage(self):
         form_data = {
@@ -151,7 +140,7 @@ class TableVizTestCase(unittest.TestCase):
                 'column': {'column_name': 'value1', 'type': 'DOUBLE'},
             }, 'count', 'avg__C'],
         }
-        datasource = TableVizTestCase.DatasourceMock()
+        datasource = self.get_datasource_mock()
         raw = {}
         raw['SUM(value1)'] = [15, 20, 25, 40]
         raw['avg__B'] = [10, 20, 5, 15]
@@ -227,7 +216,7 @@ class TableVizTestCase(unittest.TestCase):
                 },
             ],
         }
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         test_viz = viz.TableViz(datasource, form_data)
         query_obj = test_viz.query_obj()
         self.assertEqual(
@@ -265,7 +254,7 @@ class TableVizTestCase(unittest.TestCase):
             ],
             'having': 'SUM(value1) > 5',
         }
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         test_viz = viz.TableViz(datasource, form_data)
         query_obj = test_viz.query_obj()
         self.assertEqual(
@@ -281,7 +270,7 @@ class TableVizTestCase(unittest.TestCase):
 
     @patch('superset.viz.BaseViz.query_obj')
     def test_query_obj_merges_percent_metrics(self, super_query_obj):
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         form_data = {
             'percent_metrics': ['sum__A', 'avg__B', 'max__Y'],
             'metrics': ['sum__A', 'count', 'avg__C'],
@@ -299,7 +288,7 @@ class TableVizTestCase(unittest.TestCase):
 
     @patch('superset.viz.BaseViz.query_obj')
     def test_query_obj_throws_columns_and_metrics(self, super_query_obj):
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         form_data = {
             'all_columns': ['A', 'B'],
             'metrics': ['x', 'y'],
@@ -316,7 +305,7 @@ class TableVizTestCase(unittest.TestCase):
 
     @patch('superset.viz.BaseViz.query_obj')
     def test_query_obj_merges_all_columns(self, super_query_obj):
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         form_data = {
             'all_columns': ['colA', 'colB', 'colC'],
             'order_by_cols': ['["colA", "colB"]', '["colC"]'],
@@ -333,7 +322,7 @@ class TableVizTestCase(unittest.TestCase):
 
     @patch('superset.viz.BaseViz.query_obj')
     def test_query_obj_uses_sortby(self, super_query_obj):
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         form_data = {
             'timeseries_limit_metric': '__time__',
             'order_desc': False,
@@ -351,20 +340,20 @@ class TableVizTestCase(unittest.TestCase):
         )], query_obj['orderby'])
 
     def test_should_be_timeseries_raises_when_no_granularity(self):
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         form_data = {'include_time': True}
         test_viz = viz.TableViz(datasource, form_data)
         with self.assertRaises(Exception):
             test_viz.should_be_timeseries()
 
 
-class PairedTTestTestCase(unittest.TestCase):
+class PairedTTestTestCase(SupersetTestCase):
     def test_get_data_transforms_dataframe(self):
         form_data = {
             'groupby': ['groupA', 'groupB', 'groupC'],
             'metrics': ['metric1', 'metric2', 'metric3'],
         }
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         # Test data
         raw = {}
         raw[DTTM_ALIAS] = [100, 200, 300, 100, 200, 300, 100, 200, 300]
@@ -456,7 +445,7 @@ class PairedTTestTestCase(unittest.TestCase):
             'groupby': [],
             'metrics': ['', None],
         }
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         # Test data
         raw = {}
         raw[DTTM_ALIAS] = [100, 200, 300]
@@ -490,11 +479,11 @@ class PairedTTestTestCase(unittest.TestCase):
         self.assertEqual(data, expected)
 
 
-class PartitionVizTestCase(unittest.TestCase):
+class PartitionVizTestCase(SupersetTestCase):
 
     @patch('superset.viz.BaseViz.query_obj')
     def test_query_obj_time_series_option(self, super_query_obj):
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         form_data = {}
         test_viz = viz.PartitionViz(datasource, form_data)
         super_query_obj.return_value = {}
@@ -715,7 +704,7 @@ class PartitionVizTestCase(unittest.TestCase):
         self.assertEqual(7, len(test_viz.nest_values.mock_calls))
 
 
-class RoseVisTestCase(unittest.TestCase):
+class RoseVisTestCase(SupersetTestCase):
 
     def test_rose_vis_get_data(self):
         raw = {}
@@ -755,14 +744,14 @@ class RoseVisTestCase(unittest.TestCase):
         self.assertEqual(expected, res)
 
 
-class TimeSeriesTableVizTestCase(unittest.TestCase):
+class TimeSeriesTableVizTestCase(SupersetTestCase):
 
     def test_get_data_metrics(self):
         form_data = {
             'metrics': ['sum__A', 'count'],
             'groupby': [],
         }
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         raw = {}
         t1 = pd.Timestamp('2000')
         t2 = pd.Timestamp('2002')
@@ -792,7 +781,7 @@ class TimeSeriesTableVizTestCase(unittest.TestCase):
             'metrics': ['sum__A'],
             'groupby': ['groupby1'],
         }
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         raw = {}
         t1 = pd.Timestamp('2000')
         t2 = pd.Timestamp('2002')
@@ -821,7 +810,7 @@ class TimeSeriesTableVizTestCase(unittest.TestCase):
 
     @patch('superset.viz.BaseViz.query_obj')
     def test_query_obj_throws_metrics_and_groupby(self, super_query_obj):
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         form_data = {
             'groupby': ['a'],
         }
@@ -835,11 +824,11 @@ class TimeSeriesTableVizTestCase(unittest.TestCase):
             test_viz.query_obj()
 
 
-class BaseDeckGLVizTestCase(unittest.TestCase):
+class BaseDeckGLVizTestCase(SupersetTestCase):
 
     def test_get_metrics(self):
         form_data = load_fixture('deck_path_form_data.json')
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data)
         result = test_viz_deckgl.get_metrics()
         assert result == [form_data.get('size')]
@@ -851,7 +840,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
 
     def test_scatterviz_get_metrics(self):
         form_data = load_fixture('deck_path_form_data.json')
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
 
         form_data = {}
         test_viz_deckgl = viz.DeckScatterViz(datasource, form_data)
@@ -867,7 +856,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
 
     def test_get_js_columns(self):
         form_data = load_fixture('deck_path_form_data.json')
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         mock_d = {
             'a': 'dummy1',
             'b': 'dummy2',
@@ -881,7 +870,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
     def test_get_properties(self):
         mock_d = {}
         form_data = load_fixture('deck_path_form_data.json')
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data)
 
         with self.assertRaises(NotImplementedError) as context:
@@ -891,7 +880,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
 
     def test_process_spatial_query_obj(self):
         form_data = load_fixture('deck_path_form_data.json')
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         mock_key = 'spatial_key'
         mock_gb = []
         test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data)
@@ -917,7 +906,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
             },
         }
 
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         expected_results = {
             'latlong_key': ['lon', 'lat'],
             'delimited_key': ['lonlat'],
@@ -931,7 +920,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
 
     def test_geojson_query_obj(self):
         form_data = load_fixture('deck_geojson_form_data.json')
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         test_viz_deckgl = viz.DeckGeoJson(datasource, form_data)
         results = test_viz_deckgl.query_obj()
 
@@ -941,7 +930,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
 
     def test_parse_coordinates(self):
         form_data = load_fixture('deck_path_form_data.json')
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         viz_instance = viz.BaseDeckGLViz(datasource, form_data)
 
         coord = viz_instance.parse_coordinates('1.23, 3.21')
@@ -956,7 +945,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
 
     def test_parse_coordinates_raises(self):
         form_data = load_fixture('deck_path_form_data.json')
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         test_viz_deckgl = viz.BaseDeckGLViz(datasource, form_data)
 
         with self.assertRaises(SpatialException):
@@ -984,7 +973,7 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
             },
         }
 
-        datasource = {'type': 'table'}
+        datasource = self.get_datasource_mock()
         expected_results = {
             'latlong_key': [{
                 'clause': 'WHERE',
@@ -1027,10 +1016,10 @@ class BaseDeckGLVizTestCase(unittest.TestCase):
             assert expected_results.get(mock_key) == adhoc_filters
 
 
-class TimeSeriesVizTestCase(unittest.TestCase):
+class TimeSeriesVizTestCase(SupersetTestCase):
 
     def test_timeseries_unicode_data(self):
-        datasource = Mock()
+        datasource = self.get_datasource_mock()
         form_data = {
             'groupby': ['name'],
             'metrics': ['sum__payout'],