You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by kg...@apache.org on 2021/11/05 15:07:14 UTC
[superset] branch master updated: fix(dashboard): Return columns
and verbose_map for groupby values of Pivot Table v2 [ID-7] (#17287)
This is an automated email from the ASF dual-hosted git repository.
kgabryje pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new fa51b32 fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2 [ID-7] (#17287)
fa51b32 is described below
commit fa51b3234ed83a5f2910951f4cd2b1676a7b7d6d
Author: Kamil Gabryjelski <ka...@gmail.com>
AuthorDate: Fri Nov 5 16:05:48 2021 +0100
fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2 [ID-7] (#17287)
* fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2
* Refactor
* Fix test and lint
* Fix test
* Refactor
* Fix lint
---
superset/connectors/base/models.py | 27 ++++++++++++++++------
superset/examples/birth_names.py | 23 ++++++++++++++++++
superset/models/slice.py | 13 +++++++++++
tests/integration_tests/charts/api_tests.py | 6 ++---
tests/integration_tests/databases/api_tests.py | 2 +-
tests/integration_tests/model_tests.py | 32 +++++++++++++++++++++++---
6 files changed, 89 insertions(+), 14 deletions(-)
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 0bd9488..95e2054 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -282,7 +282,9 @@ class BaseDatasource(
"select_star": self.select_star,
}
- def data_for_slices(self, slices: List[Slice]) -> Dict[str, Any]:
+ def data_for_slices( # pylint: disable=too-many-locals
+ self, slices: List[Slice]
+ ) -> Dict[str, Any]:
"""
The representation of the datasource containing only the required data
to render the provided slices.
@@ -317,11 +319,23 @@ class BaseDatasource(
if "column" in filter_config
)
- column_names.update(
- column
- for column_param in COLUMN_FORM_DATA_PARAMS
- for column in utils.get_iterable(form_data.get(column_param) or [])
- )
+ # legacy charts don't have a query_context
+ query_context = slc.get_query_context()
+ if query_context:
+ column_names.update(
+ [
+ column
+ for query in query_context.queries
+ for column in query.columns
+ ]
+ or []
+ )
+ else:
+ column_names.update(
+ column
+ for column_param in COLUMN_FORM_DATA_PARAMS
+ for column in utils.get_iterable(form_data.get(column_param) or [])
+ )
filtered_metrics = [
metric
@@ -639,7 +653,6 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
class BaseMetric(AuditMixinNullable, ImportExportMixin):
-
"""Interface for Metrics"""
__tablename__: Optional[str] = None # {connector_name}_metric
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index 4a4da1c..ef964e2 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -184,6 +184,13 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
"markup_type": "markdown",
}
+ default_query_context = {
+ "result_format": "json",
+ "result_type": "full",
+ "datasource": {"id": tbl.id, "type": "table",},
+ "queries": [{"columns": [], "metrics": [],},],
+ }
+
admin = get_admin_user()
if admin_owner:
slice_props = dict(
@@ -362,6 +369,22 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
metrics=metrics,
),
),
+ Slice(
+ **slice_props,
+ slice_name="Pivot Table v2",
+ viz_type="pivot_table_v2",
+ params=get_slice_json(
+ defaults,
+ viz_type="pivot_table_v2",
+ groupbyRows=["name"],
+ groupbyColumns=["state"],
+ metrics=[metric],
+ ),
+ query_context=get_slice_json(
+ default_query_context,
+ queries=[{"columns": ["name", "state"], "metrics": [metric],}],
+ ),
+ ),
]
misc_slices = [
Slice(
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 6bf05ff..f4d7195 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -40,6 +40,7 @@ from superset.utils.urls import get_url_path
from superset.viz import BaseViz, viz_types
if TYPE_CHECKING:
+ from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource
metadata = Model.metadata # pylint: disable=no-member
@@ -247,6 +248,18 @@ class Slice( # pylint: disable=too-many-public-methods
update_time_range(form_data)
return form_data
+ def get_query_context(self) -> Optional["QueryContext"]:
+ # pylint: disable=import-outside-toplevel
+ from superset.common.query_context import QueryContext
+
+ if self.query_context:
+ try:
+ return QueryContext(**json.loads(self.query_context))
+ except json.decoder.JSONDecodeError as ex:
+ logger.error("Malformed json in slice's query context", exc_info=True)
+ logger.exception(ex)
+ return None
+
def get_explore_url(
self,
base_url: str = "/superset/explore",
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index 4c2eb02..f0c685b 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -790,7 +790,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
- self.assertEqual(data["count"], 33)
+ self.assertEqual(data["count"], 34)
def test_get_charts_changed_on(self):
"""
@@ -1040,7 +1040,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
"""
Chart API: Test get charts filter
"""
- # Assuming we have 33 sample charts
+ # Assuming we have 34 sample charts
self.login(username="admin")
arguments = {"page_size": 10, "page": 0}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
@@ -1054,7 +1054,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
- self.assertEqual(len(data["result"]), 3)
+ self.assertEqual(len(data["result"]), 4)
def test_get_charts_no_data_access(self):
"""
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index f989a6d..e13559e 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -1099,7 +1099,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.get_assert_metric(uri, "related_objects")
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
- self.assertEqual(response["charts"]["count"], 33)
+ self.assertEqual(response["charts"]["count"], 34)
self.assertEqual(response["dashboards"]["count"], 3)
def test_get_database_related_objects_not_found(self):
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 56956c3..bc6349c 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -518,7 +518,7 @@ class TestSqlaTableModel(SupersetTestCase):
self.assertTrue("Metric 'invalid' does not exist", context.exception)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
- def test_data_for_slices(self):
+ def test_data_for_slices_with_no_query_context(self):
tbl = self.get_table(name="birth_names")
slc = (
metadata_db.session.query(Slice)
@@ -532,9 +532,35 @@ class TestSqlaTableModel(SupersetTestCase):
assert len(data_for_slices["columns"]) == 1
assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
assert data_for_slices["columns"][0]["column_name"] == "gender"
- assert set(data_for_slices["verbose_map"].keys()) == set(
- ["__timestamp", "sum__num", "gender",]
+ assert set(data_for_slices["verbose_map"].keys()) == {
+ "__timestamp",
+ "sum__num",
+ "gender",
+ }
+
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+ def test_data_for_slices_with_query_context(self):
+ tbl = self.get_table(name="birth_names")
+ slc = (
+ metadata_db.session.query(Slice)
+ .filter_by(
+ datasource_id=tbl.id,
+ datasource_type=tbl.type,
+ slice_name="Pivot Table v2",
+ )
+ .first()
)
+ data_for_slices = tbl.data_for_slices([slc])
+ assert len(data_for_slices["metrics"]) == 1
+ assert len(data_for_slices["columns"]) == 2
+ assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
+ assert data_for_slices["columns"][0]["column_name"] == "name"
+ assert set(data_for_slices["verbose_map"].keys()) == {
+ "__timestamp",
+ "sum__num",
+ "name",
+ "state",
+ }
def test_literal_dttm_type_factory():