You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by kg...@apache.org on 2021/11/05 15:07:14 UTC

[superset] branch master updated: fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2 [ID-7] (#17287)

This is an automated email from the ASF dual-hosted git repository.

kgabryje pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new fa51b32  fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2 [ID-7] (#17287)
fa51b32 is described below

commit fa51b3234ed83a5f2910951f4cd2b1676a7b7d6d
Author: Kamil Gabryjelski <ka...@gmail.com>
AuthorDate: Fri Nov 5 16:05:48 2021 +0100

    fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2 [ID-7] (#17287)
    
    * fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2
    
    * Refactor
    
    * Fix test and lint
    
    * Fix test
    
    * Refactor
    
    * Fix lint
---
 superset/connectors/base/models.py             | 27 ++++++++++++++++------
 superset/examples/birth_names.py               | 23 ++++++++++++++++++
 superset/models/slice.py                       | 13 +++++++++++
 tests/integration_tests/charts/api_tests.py    |  6 ++---
 tests/integration_tests/databases/api_tests.py |  2 +-
 tests/integration_tests/model_tests.py         | 32 +++++++++++++++++++++++---
 6 files changed, 89 insertions(+), 14 deletions(-)

diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 0bd9488..95e2054 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -282,7 +282,9 @@ class BaseDatasource(
             "select_star": self.select_star,
         }
 
-    def data_for_slices(self, slices: List[Slice]) -> Dict[str, Any]:
+    def data_for_slices(  # pylint: disable=too-many-locals
+        self, slices: List[Slice]
+    ) -> Dict[str, Any]:
         """
         The representation of the datasource containing only the required data
         to render the provided slices.
@@ -317,11 +319,23 @@ class BaseDatasource(
                 if "column" in filter_config
             )
 
-            column_names.update(
-                column
-                for column_param in COLUMN_FORM_DATA_PARAMS
-                for column in utils.get_iterable(form_data.get(column_param) or [])
-            )
+            # legacy charts don't have query_context charts
+            query_context = slc.get_query_context()
+            if query_context:
+                column_names.update(
+                    [
+                        column
+                        for query in query_context.queries
+                        for column in query.columns
+                    ]
+                    or []
+                )
+            else:
+                column_names.update(
+                    column
+                    for column_param in COLUMN_FORM_DATA_PARAMS
+                    for column in utils.get_iterable(form_data.get(column_param) or [])
+                )
 
         filtered_metrics = [
             metric
@@ -639,7 +653,6 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
 
 
 class BaseMetric(AuditMixinNullable, ImportExportMixin):
-
     """Interface for Metrics"""
 
     __tablename__: Optional[str] = None  # {connector_name}_metric
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index 4a4da1c..ef964e2 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -184,6 +184,13 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
         "markup_type": "markdown",
     }
 
+    default_query_context = {
+        "result_format": "json",
+        "result_type": "full",
+        "datasource": {"id": tbl.id, "type": "table",},
+        "queries": [{"columns": [], "metrics": [],},],
+    }
+
     admin = get_admin_user()
     if admin_owner:
         slice_props = dict(
@@ -362,6 +369,22 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
                 metrics=metrics,
             ),
         ),
+        Slice(
+            **slice_props,
+            slice_name="Pivot Table v2",
+            viz_type="pivot_table_v2",
+            params=get_slice_json(
+                defaults,
+                viz_type="pivot_table_v2",
+                groupbyRows=["name"],
+                groupbyColumns=["state"],
+                metrics=[metric],
+            ),
+            query_context=get_slice_json(
+                default_query_context,
+                queries=[{"columns": ["name", "state"], "metrics": [metric],}],
+            ),
+        ),
     ]
     misc_slices = [
         Slice(
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 6bf05ff..f4d7195 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -40,6 +40,7 @@ from superset.utils.urls import get_url_path
 from superset.viz import BaseViz, viz_types
 
 if TYPE_CHECKING:
+    from superset.common.query_context import QueryContext
     from superset.connectors.base.models import BaseDatasource
 
 metadata = Model.metadata  # pylint: disable=no-member
@@ -247,6 +248,18 @@ class Slice(  # pylint: disable=too-many-public-methods
         update_time_range(form_data)
         return form_data
 
+    def get_query_context(self) -> Optional["QueryContext"]:
+        # pylint: disable=import-outside-toplevel
+        from superset.common.query_context import QueryContext
+
+        if self.query_context:
+            try:
+                return QueryContext(**json.loads(self.query_context))
+            except json.decoder.JSONDecodeError as ex:
+                logger.error("Malformed json in slice's query context", exc_info=True)
+                logger.exception(ex)
+        return None
+
     def get_explore_url(
         self,
         base_url: str = "/superset/explore",
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index 4c2eb02..f0c685b 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -790,7 +790,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         rv = self.get_assert_metric(uri, "get_list")
         self.assertEqual(rv.status_code, 200)
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(data["count"], 33)
+        self.assertEqual(data["count"], 34)
 
     def test_get_charts_changed_on(self):
         """
@@ -1040,7 +1040,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         """
         Chart API: Test get charts filter
         """
-        # Assuming we have 33 sample charts
+        # Assuming we have 34 sample charts
         self.login(username="admin")
         arguments = {"page_size": 10, "page": 0}
         uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
@@ -1054,7 +1054,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         rv = self.get_assert_metric(uri, "get_list")
         self.assertEqual(rv.status_code, 200)
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(len(data["result"]), 3)
+        self.assertEqual(len(data["result"]), 4)
 
     def test_get_charts_no_data_access(self):
         """
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index f989a6d..e13559e 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -1099,7 +1099,7 @@ class TestDatabaseApi(SupersetTestCase):
         rv = self.get_assert_metric(uri, "related_objects")
         self.assertEqual(rv.status_code, 200)
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(response["charts"]["count"], 33)
+        self.assertEqual(response["charts"]["count"], 34)
         self.assertEqual(response["dashboards"]["count"], 3)
 
     def test_get_database_related_objects_not_found(self):
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 56956c3..bc6349c 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -518,7 +518,7 @@ class TestSqlaTableModel(SupersetTestCase):
         self.assertTrue("Metric 'invalid' does not exist", context.exception)
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
-    def test_data_for_slices(self):
+    def test_data_for_slices_with_no_query_context(self):
         tbl = self.get_table(name="birth_names")
         slc = (
             metadata_db.session.query(Slice)
@@ -532,9 +532,35 @@ class TestSqlaTableModel(SupersetTestCase):
         assert len(data_for_slices["columns"]) == 1
         assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
         assert data_for_slices["columns"][0]["column_name"] == "gender"
-        assert set(data_for_slices["verbose_map"].keys()) == set(
-            ["__timestamp", "sum__num", "gender",]
+        assert set(data_for_slices["verbose_map"].keys()) == {
+            "__timestamp",
+            "sum__num",
+            "gender",
+        }
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_data_for_slices_with_query_context(self):
+        tbl = self.get_table(name="birth_names")
+        slc = (
+            metadata_db.session.query(Slice)
+            .filter_by(
+                datasource_id=tbl.id,
+                datasource_type=tbl.type,
+                slice_name="Pivot Table v2",
+            )
+            .first()
         )
+        data_for_slices = tbl.data_for_slices([slc])
+        assert len(data_for_slices["metrics"]) == 1
+        assert len(data_for_slices["columns"]) == 2
+        assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
+        assert data_for_slices["columns"][0]["column_name"] == "name"
+        assert set(data_for_slices["verbose_map"].keys()) == {
+            "__timestamp",
+            "sum__num",
+            "name",
+            "state",
+        }
 
 
 def test_literal_dttm_type_factory():