You are viewing a plain-text version of this content. The canonical link to it (a hyperlink on the original archive page) is not preserved in this plain-text copy.
Posted to commits@superset.apache.org by vi...@apache.org on 2020/07/08 10:36:29 UTC

[incubator-superset] branch master updated: feat(chart-data-api): make pivoted columns flattenable (#10255)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new baeacc3  feat(chart-data-api): make pivoted columns flattenable (#10255)
baeacc3 is described below

commit baeacc3c560dbd2ac9543912ca5559e112118d68
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Wed Jul 8 13:35:53 2020 +0300

    feat(chart-data-api): make pivoted columns flattenable (#10255)
    
    * feat(chart-data-api): make pivoted columns flattenable
    
    * Linting + improve tests
---
 superset/charts/schemas.py              |   2 -
 superset/utils/pandas_postprocessing.py |  42 ++++++++++--
 tests/pandas_postprocessing_tests.py    | 111 ++++++++++++++++++++++++++++----
 3 files changed, 134 insertions(+), 21 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 60cec21..8ab4859 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -414,8 +414,6 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
         fields.String(
             allow_none=False, description="Columns to group by on the table columns",
         ),
-        minLength=1,
-        required=True,
     )
     metric_fill_value = fields.Number(
         description="Value to replace missing values with in aggregate calculations.",
diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index e62b393..b693977 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -72,13 +72,38 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
 )
 
 
+def _flatten_column_after_pivot(
+    column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]]
+) -> str:
+    """
+    Function for flattening column names into a single string. This step is necessary
+    to be able to properly serialize a DataFrame. If the column is a string, return
+    element unchanged. For multi-element columns, join column elements with a comma,
+    with the exception of pivots made with a single aggregate, in which case the
+    aggregate column name is omitted.
+
+    :param column: single element from `DataFrame.columns`
+    :param aggregates: aggregates
+    :return:
+    """
+    if isinstance(column, str):
+        return column
+    if len(column) == 1:
+        return column[0]
+    if len(aggregates) == 1 and len(column) > 1:
+        # drop aggregate for single aggregate pivots with multiple groupings
+        # from column name (aggregates always come first in column name)
+        column = column[1:]
+    return ", ".join(column)
+
+
 def validate_column_args(*argnames: str) -> Callable[..., Any]:
     def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
         def wrapped(df: DataFrame, **options: Any) -> Any:
             columns = df.columns.tolist()
             for name in argnames:
                 if name in options and not all(
-                    elem in columns for elem in options[name]
+                    elem in columns for elem in options.get(name) or []
                 ):
                     raise QueryObjectValidationError(
                         _("Referenced columns not available in DataFrame.")
@@ -154,14 +179,15 @@ def _append_columns(
 def pivot(  # pylint: disable=too-many-arguments
     df: DataFrame,
     index: List[str],
-    columns: List[str],
     aggregates: Dict[str, Dict[str, Any]],
+    columns: Optional[List[str]] = None,
     metric_fill_value: Optional[Any] = None,
     column_fill_value: Optional[str] = None,
     drop_missing_columns: Optional[bool] = True,
     combine_value_with_metric: bool = False,
     marginal_distributions: Optional[bool] = None,
     marginal_distribution_name: Optional[str] = None,
+    flatten_columns: bool = True,
 ) -> DataFrame:
     """
     Perform a pivot operation on a DataFrame.
@@ -179,6 +205,7 @@ def pivot(  # pylint: disable=too-many-arguments
     :param marginal_distributions: Add totals for row/column. Default to False
     :param marginal_distribution_name: Name of row/column with marginal distribution.
            Default to 'All'.
+    :param flatten_columns: Convert column names to strings
     :return: A pivot table
     :raises ChartDataValidationError: If the request in incorrect
     """
@@ -186,10 +213,6 @@ def pivot(  # pylint: disable=too-many-arguments
         raise QueryObjectValidationError(
             _("Pivot operation requires at least one index")
         )
-    if not columns:
-        raise QueryObjectValidationError(
-            _("Pivot operation requires at least one column")
-        )
     if not aggregates:
         raise QueryObjectValidationError(
             _("Pivot operation must include at least one aggregate")
@@ -218,6 +241,13 @@ def pivot(  # pylint: disable=too-many-arguments
     if combine_value_with_metric:
         df = df.stack(0).unstack()
 
+    # Make index regular column
+    if flatten_columns:
+        df.columns = [
+            _flatten_column_after_pivot(col, aggregates) for col in df.columns
+        ]
+    # return index as regular column
+    df.reset_index(level=0, inplace=True)
     return df
 
 
diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py
index 839b227..87d2cc1 100644
--- a/tests/pandas_postprocessing_tests.py
+++ b/tests/pandas_postprocessing_tests.py
@@ -26,6 +26,12 @@ from superset.utils import pandas_postprocessing as proc
 from .base_tests import SupersetTestCase
 from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df
 
+AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
+AGGREGATES_MULTIPLE = {
+    "idx_nulls": {"operator": "sum"},
+    "asc_idx": {"operator": "mean"},
+}
+
 
 def series_to_list(series: Series) -> List[Any]:
     """
@@ -57,33 +63,99 @@ def round_floats(
 
 
 class TestPostProcessing(SupersetTestCase):
-    def test_pivot(self):
-        aggregates = {"idx_nulls": {"operator": "sum"}}
+    def test_flatten_column_after_pivot(self):
+        """
+        Test pivot column flattening function
+        """
+        # single aggregate cases
+        self.assertEqual(
+            proc._flatten_column_after_pivot(
+                aggregates=AGGREGATES_SINGLE, column="idx_nulls",
+            ),
+            "idx_nulls",
+        )
+        self.assertEqual(
+            proc._flatten_column_after_pivot(
+                aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"),
+            ),
+            "col1",
+        )
+        self.assertEqual(
+            proc._flatten_column_after_pivot(
+                aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", "col2"),
+            ),
+            "col1, col2",
+        )
+
+        # Multiple aggregate cases
+        self.assertEqual(
+            proc._flatten_column_after_pivot(
+                aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"),
+            ),
+            "idx_nulls, asc_idx, col1",
+        )
+        self.assertEqual(
+            proc._flatten_column_after_pivot(
+                aggregates=AGGREGATES_MULTIPLE,
+                column=("idx_nulls", "asc_idx", "col1", "col2"),
+            ),
+            "idx_nulls, asc_idx, col1, col2",
+        )
+
+    def test_pivot_without_columns(self):
+        """
+        Make sure pivot without columns returns correct DataFrame
+        """
+        df = proc.pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,)
+        self.assertListEqual(
+            df.columns.tolist(), ["name", "idx_nulls"],
+        )
+        self.assertEqual(len(df), 101)
+        self.assertEqual(df.sum()[1], 1050)
 
-        # regular pivot
+    def test_pivot_with_single_column(self):
+        """
+        Make sure pivot with single column returns correct DataFrame
+        """
         df = proc.pivot(
             df=categories_df,
             index=["name"],
             columns=["category"],
-            aggregates=aggregates,
+            aggregates=AGGREGATES_SINGLE,
         )
         self.assertListEqual(
-            df.columns.tolist(),
-            [("idx_nulls", "cat0"), ("idx_nulls", "cat1"), ("idx_nulls", "cat2")],
+            df.columns.tolist(), ["name", "cat0", "cat1", "cat2"],
         )
         self.assertEqual(len(df), 101)
-        self.assertEqual(df.sum()[0], 315)
+        self.assertEqual(df.sum()[1], 315)
 
-        # regular pivot
         df = proc.pivot(
             df=categories_df,
             index=["dept"],
             columns=["category"],
-            aggregates=aggregates,
+            aggregates=AGGREGATES_SINGLE,
+        )
+        self.assertListEqual(
+            df.columns.tolist(), ["dept", "cat0", "cat1", "cat2"],
         )
         self.assertEqual(len(df), 5)
 
-        # fill value
+    def test_pivot_with_multiple_columns(self):
+        """
+        Make sure pivot with multiple columns returns correct DataFrame
+        """
+        df = proc.pivot(
+            df=categories_df,
+            index=["name"],
+            columns=["category", "dept"],
+            aggregates=AGGREGATES_SINGLE,
+        )
+        self.assertEqual(len(df.columns), 1 + 3 * 5)  # index + possible permutations
+
+    def test_pivot_fill_values(self):
+        """
+        Make sure pivot with fill values returns correct DataFrame
+        """
         df = proc.pivot(
             df=categories_df,
             index=["name"],
@@ -91,7 +163,20 @@ class TestPostProcessing(SupersetTestCase):
             metric_fill_value=1,
             aggregates={"idx_nulls": {"operator": "sum"}},
         )
-        self.assertEqual(df.sum()[0], 382)
+        self.assertEqual(df.sum()[1], 382)
+
+    def test_pivot_exceptions(self):
+        """
+        Make sure pivot raises correct Exceptions
+        """
+        # Missing index
+        self.assertRaises(
+            TypeError,
+            proc.pivot,
+            df=categories_df,
+            columns=["dept"],
+            aggregates=AGGREGATES_SINGLE,
+        )
 
         # invalid index reference
         self.assertRaises(
@@ -100,7 +185,7 @@ class TestPostProcessing(SupersetTestCase):
             df=categories_df,
             index=["abc"],
             columns=["dept"],
-            aggregates=aggregates,
+            aggregates=AGGREGATES_SINGLE,
         )
 
         # invalid column reference
@@ -110,7 +195,7 @@ class TestPostProcessing(SupersetTestCase):
             df=categories_df,
             index=["dept"],
             columns=["abc"],
-            aggregates=aggregates,
+            aggregates=AGGREGATES_SINGLE,
         )
 
         # invalid aggregate options