You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2020/07/28 07:59:55 UTC
[incubator-superset] branch 0.37 updated: feat: support non-numeric columns in pivot table (#10389)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch 0.37
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/0.37 by this push:
new 0f3670e feat: support non-numeric columns in pivot table (#10389)
0f3670e is described below
commit 0f3670e1af8778aefefa179f02453cec1766e6e5
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Tue Jul 28 10:40:53 2020 +0300
feat: support non-numeric columns in pivot table (#10389)
* fix: support non-numeric columns in pivot table
* bump package and add unit tests
* mypy
---
superset/viz.py | 39 +++++++++++++++++++++++++++++++--------
tests/viz_tests.py | 38 ++++++++++++++++++++++++++++++++++++++
2 files changed, 69 insertions(+), 8 deletions(-)
diff --git a/superset/viz.py b/superset/viz.py
index 2067f7c..8cb2aa6 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -29,7 +29,18 @@ import uuid
from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta
from itertools import product
-from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
+from typing import (
+ Any,
+ Callable,
+ cast,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TYPE_CHECKING,
+ Union,
+)
import dataclasses
import geohash
@@ -736,6 +747,7 @@ class PivotTableViz(BaseViz):
verbose_name = _("Pivot Table")
credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
is_timeseries = False
+ enforce_numerical_metrics = False
def query_obj(self) -> QueryObjectDict:
d = super().query_obj()
@@ -766,6 +778,18 @@ class PivotTableViz(BaseViz):
raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap"))
return d
@staticmethod
def get_aggfunc(
    metric: str, df: pd.DataFrame, form_data: Dict[str, Any]
) -> Union[str, Callable[[Any], Any]]:
    """Return the pandas aggregation to apply to ``metric`` in the pivot table.

    :param metric: name of the metric column in ``df``
    :param df: query result the pivot table is built from
    :param form_data: chart form data; ``pandas_aggfunc`` selects the
        aggregation and defaults to ``"sum"`` when missing or empty
    :return: an aggregation name accepted by ``DataFrame.pivot_table``, or a
        callable implementing a SQL-like sum for numeric columns
    """
    aggfunc = form_data.get("pandas_aggfunc") or "sum"
    if pd.api.types.is_numeric_dtype(df[metric]):
        if aggfunc == "sum":
            # Ensure that Pandas's sum function mimics that of SQL: with
            # min_count=1 an all-null group sums to NaN instead of 0.
            return lambda x: x.sum(min_count=1)
        # Every other aggregation (mean, median, std, ...) is valid for
        # numeric data and must be passed through as requested rather than
        # falling into the non-numeric "max" fallback below.
        return aggfunc
    # only min and max work properly for non-numerics
    return aggfunc if aggfunc in ("min", "max") else "max"
def get_data(self, df: pd.DataFrame) -> VizData:
if df.empty:
return None
@@ -773,22 +797,21 @@ class PivotTableViz(BaseViz):
if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS]
- aggfunc = self.form_data.get("pandas_aggfunc") or "sum"
-
- # Ensure that Pandas's sum function mimics that of SQL.
- if aggfunc == "sum":
- aggfunc = lambda x: x.sum(min_count=1)
+ metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
+ aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
+ for metric in metrics:
+ aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data)
groupby = self.form_data.get("groupby")
columns = self.form_data.get("columns")
if self.form_data.get("transpose_pivot"):
groupby, columns = columns, groupby
- metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
+
df = df.pivot_table(
index=groupby,
columns=columns,
values=metrics,
- aggfunc=aggfunc,
+ aggfunc=aggfuncs,
margins=self.form_data.get("pivot_margins"),
)
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index 8290fbf..17e43d8 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -1284,3 +1284,41 @@ class TestBigNumberViz(SupersetTestCase):
)
data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df)
assert np.isnan(data[2]["y"])
+
+
class TestPivotTableViz(SupersetTestCase):
    """Unit tests for ``PivotTableViz.get_aggfunc``."""

    # Shared fixture: an integer-valued, a float and a string column, each
    # ending in a null value.
    df = pd.DataFrame(
        data={
            "intcol": [1, 2, 3, None],
            "floatcol": [0.1, 0.2, 0.3, None],
            "strcol": ["a", "b", "c", None],
        }
    )

    def test_get_aggfunc_numeric(self):
        get_aggfunc = viz.PivotTableViz.get_aggfunc

        # With no aggregation requested, a numeric column gets a sum function.
        sum_func = get_aggfunc("intcol", self.df, {})
        assert callable(sum_func)
        assert sum_func(self.df["intcol"]) == 6

        # Explicit min / max requests come back unchanged.
        int_agg = get_aggfunc("intcol", self.df, {"pandas_aggfunc": "min"})
        assert int_agg == "min"
        float_agg = get_aggfunc("floatcol", self.df, {"pandas_aggfunc": "max"})
        assert float_agg == "max"

    def test_get_aggfunc_non_numeric(self):
        get_aggfunc = viz.PivotTableViz.get_aggfunc
        # (form_data, expected aggregation) pairs for a string column.
        cases = [
            ({}, "max"),
            ({"pandas_aggfunc": "sum"}, "max"),
            ({"pandas_aggfunc": "min"}, "min"),
        ]
        for form_data, expected in cases:
            assert get_aggfunc("strcol", self.df, form_data) == expected