You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2020/07/28 07:59:55 UTC

[incubator-superset] branch 0.37 updated: feat: support non-numeric columns in pivot table (#10389)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 0.37
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/0.37 by this push:
     new 0f3670e  feat: support non-numeric columns in pivot table (#10389)
0f3670e is described below

commit 0f3670e1af8778aefefa179f02453cec1766e6e5
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Tue Jul 28 10:40:53 2020 +0300

    feat: support non-numeric columns in pivot table (#10389)
    
    * fix: support non-numeric columns in pivot table
    
    * bump package and add unit tests
    
    * mypy
---
 superset/viz.py    | 39 +++++++++++++++++++++++++++++++--------
 tests/viz_tests.py | 38 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 69 insertions(+), 8 deletions(-)

diff --git a/superset/viz.py b/superset/viz.py
index 2067f7c..8cb2aa6 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -29,7 +29,18 @@ import uuid
 from collections import defaultdict, OrderedDict
 from datetime import datetime, timedelta
 from itertools import product
-from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 
 import dataclasses
 import geohash
@@ -736,6 +747,7 @@ class PivotTableViz(BaseViz):
     verbose_name = _("Pivot Table")
     credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
     is_timeseries = False
+    enforce_numerical_metrics = False
 
     def query_obj(self) -> QueryObjectDict:
         d = super().query_obj()
@@ -766,6 +778,18 @@ class PivotTableViz(BaseViz):
             raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap"))
         return d
 
+    @staticmethod
+    def get_aggfunc(
+        metric: str, df: pd.DataFrame, form_data: Dict[str, Any]
+    ) -> Union[str, Callable[[Any], Any]]:
+        """Return the pandas aggfunc (name or callable) to use for ``metric``."""
+        aggfunc = form_data.get("pandas_aggfunc") or "sum"
+        if not pd.api.types.is_numeric_dtype(df[metric]):  # only min/max work
+            return aggfunc if aggfunc in ("min", "max") else "max"
+        if aggfunc == "sum":  # ensure that pandas' sum mimics that of SQL
+            return lambda x: x.sum(min_count=1)
+        return aggfunc  # numeric columns keep whatever function was requested
+
     def get_data(self, df: pd.DataFrame) -> VizData:
         if df.empty:
             return None
@@ -773,22 +797,21 @@ class PivotTableViz(BaseViz):
         if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df:
             del df[DTTM_ALIAS]
 
-        aggfunc = self.form_data.get("pandas_aggfunc") or "sum"
-
-        # Ensure that Pandas's sum function mimics that of SQL.
-        if aggfunc == "sum":
-            aggfunc = lambda x: x.sum(min_count=1)
+        metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
+        aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
+        for metric in metrics:
+            aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data)
 
         groupby = self.form_data.get("groupby")
         columns = self.form_data.get("columns")
         if self.form_data.get("transpose_pivot"):
             groupby, columns = columns, groupby
-        metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
+
         df = df.pivot_table(
             index=groupby,
             columns=columns,
             values=metrics,
-            aggfunc=aggfunc,
+            aggfunc=aggfuncs,
             margins=self.form_data.get("pivot_margins"),
         )
 
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index 8290fbf..17e43d8 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -1284,3 +1284,41 @@ class TestBigNumberViz(SupersetTestCase):
         )
         data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df)
         assert np.isnan(data[2]["y"])
+
+
+class TestPivotTableViz(SupersetTestCase):
+    """Unit tests for ``viz.PivotTableViz.get_aggfunc``."""
+
+    # Fixture with an integer, a float and a string column; every column
+    # ends with a null entry.
+    df = pd.DataFrame(
+        data={
+            "intcol": [1, 2, 3, None],
+            "floatcol": [0.1, 0.2, 0.3, None],
+            "strcol": ["a", "b", "c", None],
+        }
+    )
+
+    def test_get_aggfunc_numeric(self):
+        get_aggfunc = viz.PivotTableViz.get_aggfunc
+
+        # Without an explicit pandas_aggfunc a callable is returned that
+        # sums the column, skipping the null entry.
+        default_func = get_aggfunc("intcol", self.df, {})
+        assert hasattr(default_func, "__call__")
+        assert default_func(self.df["intcol"]) == 6
+
+        # min and max are handed back as plain strings.
+        assert get_aggfunc("intcol", self.df, {"pandas_aggfunc": "min"}) == "min"
+        assert get_aggfunc("floatcol", self.df, {"pandas_aggfunc": "max"}) == "max"
+
+    def test_get_aggfunc_non_numeric(self):
+        get_aggfunc = viz.PivotTableViz.get_aggfunc
+
+        # Neither the default nor an explicit "sum" is kept for strings:
+        # both come back as "max".
+        assert get_aggfunc("strcol", self.df, {}) == "max"
+        assert get_aggfunc("strcol", self.df, {"pandas_aggfunc": "sum"}) == "max"
+
+        # "min" is returned unchanged.
+        assert get_aggfunc("strcol", self.df, {"pandas_aggfunc": "min"}) == "min"