You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2021/08/17 21:42:31 UTC

[superset] branch master updated: fix: improve pivot post-processing (#16289)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new ac8e54d  fix: improve pivot post-processing (#16289)
ac8e54d is described below

commit ac8e54d9094aaba05a87167da611b169464064da
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Tue Aug 17 14:41:22 2021 -0700

    fix: improve pivot post-processing (#16289)
    
    * fix: improve pivot post-processing
    
    * Add tests
    
    * Trim space from column name
---
 superset/charts/post_processing.py              | 281 ++++++---
 tests/unit_tests/charts/test_post_processing.py | 764 +++++++++++++++++++++++-
 2 files changed, 927 insertions(+), 118 deletions(-)

diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py
index b67d870..4919907 100644
--- a/superset/charts/post_processing.py
+++ b/superset/charts/post_processing.py
@@ -27,60 +27,151 @@ for these chart types.
 """
 
 from io import StringIO
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple
 
 import pandas as pd
 
 from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name
 
 
-def sql_like_sum(series: pd.Series) -> pd.Series:
+def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
     """
-    A SUM aggregation function that mimics the behavior from SQL.
-    """
-    return series.sum(min_count=1)
-
+    Return a sort key for a column label when combining metrics.
 
-def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
+    MultiIndex labels have the metric name as the last element; it is replaced
+    by its position in ``metrics`` so columns sort in the user-defined order.
     """
-    Pivot table.
-    """
-    if form_data.get("granularity") == "all" and DTTM_ALIAS in df:
-        del df[DTTM_ALIAS]
-
-    metrics = [get_metric_name(m) for m in form_data["metrics"]]
-    aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
-    for metric in metrics:
-        aggfunc = form_data.get("pandas_aggfunc") or "sum"
-        if pd.api.types.is_numeric_dtype(df[metric]):
-            if aggfunc == "sum":
-                aggfunc = sql_like_sum
-        elif aggfunc not in {"min", "max"}:
-            aggfunc = "max"
-        aggfuncs[metric] = aggfunc
-
-    groupby = form_data.get("groupby") or []
-    columns = form_data.get("columns") or []
-    if form_data.get("transpose_pivot"):
-        groupby, columns = columns, groupby
-
-    df = df.pivot_table(
-        index=groupby,
-        columns=columns,
-        values=metrics,
-        aggfunc=aggfuncs,
-        margins=form_data.get("pivot_margins"),
-    )
-
-    # Display metrics side by side with each column
-    if form_data.get("combine_metric"):
-        df = df.stack(0).unstack().reindex(level=-1, columns=metrics)
-
-    # flatten column names
-    df.columns = [
-        " ".join(str(name) for name in column) if isinstance(column, tuple) else column
-        for column in df.columns
-    ]
+    parts: List[Any] = list(label)
+    metric = parts[-1]
+    parts[-1] = metrics.index(metric)  # raises ValueError if metric is unknown
+    return tuple(parts)
+
+
+def pivot_df(  # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches
+    df: pd.DataFrame,
+    rows: List[str],
+    columns: List[str],
+    metrics: List[str],
+    aggfunc: str = "Sum",
+    transpose_pivot: bool = False,
+    combine_metrics: bool = False,
+    show_rows_total: bool = False,
+    show_columns_total: bool = False,
+    apply_metrics_on_rows: bool = False,
+) -> pd.DataFrame:
+    metric_name = f"Total ({aggfunc})"  # label for grand-total rows/columns
+
+    if transpose_pivot:
+        rows, columns = columns, rows
+
+    # to apply the metrics on the rows we pivot the dataframe, apply the
+    # metrics to the columns, and pivot the dataframe back before
+    # returning it
+    if apply_metrics_on_rows:
+        rows, columns = columns, rows
+        axis = {"columns": 0, "rows": 1}
+    else:
+        axis = {"columns": 1, "rows": 0}
+
+    # pivot data; we'll compute totals and subtotals later
+    if rows or columns:
+        df = df.pivot_table(
+            index=rows,
+            columns=columns,
+            values=metrics,
+            aggfunc=pivot_v2_aggfunc_map[aggfunc],
+            margins=False,
+        )
+    else:
+        # with neither rows nor columns there is a single row of metric
+        # values; label it with `metric_name` so it shows up in the table
+        df.index = pd.Index([*df.index[:-1], metric_name], name="metric")
+
+    # if no rows were passed the metrics will be in the rows, so we
+    # need to move them back to columns
+    if columns and not rows:
+        df = df.stack().to_frame().T
+        df = df[metrics]
+        df.index = pd.Index([*df.index[:-1], metric_name], name="metric")
+
+    # combining metrics changes the column hierarchy, moving the metric
+    # from the top to the bottom, eg:
+    #
+    # ('SUM(col)', 'age', 'name') => ('age', 'name', 'SUM(col)')
+    if combine_metrics and isinstance(df.columns, pd.MultiIndex):
+        # move metrics to the lowest level
+        new_order = [*range(1, df.columns.nlevels), 0]
+        df = df.reorder_levels(new_order, axis=1)
+
+        # sort columns, combining metrics for each group
+        decorated_columns = [(col, i) for i, col in enumerate(df.columns)]
+        grouped_columns = sorted(
+            decorated_columns, key=lambda t: get_column_key(t[0], metrics)
+        )
+        indexes = [i for col, i in grouped_columns]
+        df = df[df.columns[indexes]]
+    elif rows:
+        # if metrics were not combined we sort the dataframe by the list
+        # of metrics defined by the user
+        df = df[metrics]
+
+    # compute fractions, if needed
+    if aggfunc.endswith(" as Fraction of Total"):
+        total = df.sum().sum()
+        df = df.astype(total.dtypes) / total
+    elif aggfunc.endswith(" as Fraction of Columns"):
+        total = df.sum(axis=axis["rows"])
+        df = df.astype(total.dtypes).div(total, axis=axis["columns"])
+    elif aggfunc.endswith(" as Fraction of Rows"):
+        total = df.sum(axis=axis["columns"])
+        df = df.astype(total.dtypes).div(total, axis=axis["rows"])
+
+    if show_rows_total:  # appends total/subtotal *columns*
+        # convert to a MultiIndex to simplify logic
+        if not isinstance(df.columns, pd.MultiIndex):
+            df.columns = pd.MultiIndex.from_tuples([(str(i),) for i in df.columns])
+
+        # add subtotal for each group and overall total; we start from the
+        # overall group, and iterate deeper into subgroups
+        groups = df.columns
+        for level in range(df.columns.nlevels):
+            subgroups = {group[:level] for group in groups}
+            for subgroup in subgroups:
+                slice_ = df.columns.get_loc(subgroup)
+                subtotal = pivot_v2_aggfunc_map[aggfunc](df.iloc[:, slice_], axis=1)
+                depth = df.columns.nlevels - len(subgroup) - 1
+                total = metric_name if level == 0 else "Subtotal"
+                subtotal_name = tuple([*subgroup, total, *([""] * depth)])
+                # insert column after subgroup
+                df.insert(int(slice_.stop), subtotal_name, subtotal)
+
+    if rows and show_columns_total:  # appends total/subtotal *rows*
+        # convert to a MultiIndex to simplify logic
+        if not isinstance(df.index, pd.MultiIndex):
+            df.index = pd.MultiIndex.from_tuples([(str(i),) for i in df.index])
+
+        # add subtotal for each group and overall total; we start from the
+        # overall group, and iterate deeper into subgroups
+        groups = df.index
+        for level in range(df.index.nlevels):
+            subgroups = {group[:level] for group in groups}
+            for subgroup in subgroups:
+                slice_ = df.index.get_loc(subgroup)
+                subtotal = pivot_v2_aggfunc_map[aggfunc](
+                    df.iloc[slice_, :].apply(pd.to_numeric), axis=0
+                )
+                depth = df.index.nlevels - len(subgroup) - 1
+                total = metric_name if level == 0 else "Subtotal"
+                subtotal.name = tuple([*subgroup, total, *([""] * depth)])
+                # insert row after subgroup
+                df = pd.concat(
+                    [df[: slice_.stop], subtotal.to_frame().T, df[slice_.stop :]]
+                )
+
+    # if we want to apply the metrics on the rows we need to pivot the
+    # dataframe back
+    if apply_metrics_on_rows:
+        df = df.T
 
     return df
 
@@ -125,61 +216,49 @@ def pivot_table_v2(  # pylint: disable=too-many-branches
     if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df:
         del df[DTTM_ALIAS]
 
-    # TODO (betodealmeida): implement metricsLayout
-    metrics = [get_metric_name(m) for m in form_data["metrics"]]
-    aggregate_function = form_data.get("aggregateFunction", "Sum")
-    groupby = form_data.get("groupbyRows") or []
-    columns = form_data.get("groupbyColumns") or []
-    if form_data.get("transposePivot"):
-        groupby, columns = columns, groupby
-
-    df = df.pivot_table(
-        index=groupby,
-        columns=columns,
-        values=metrics,
-        aggfunc=pivot_v2_aggfunc_map[aggregate_function],
-        margins=True,
+    return pivot_df(
+        df,
+        rows=form_data.get("groupbyRows") or [],
+        columns=form_data.get("groupbyColumns") or [],
+        metrics=[get_metric_name(m) for m in form_data["metrics"]],
+        aggfunc=form_data.get("aggregateFunction", "Sum"),
+        transpose_pivot=bool(form_data.get("transposePivot")),
+        combine_metrics=bool(form_data.get("combineMetric")),
+        show_rows_total=bool(form_data.get("rowTotals")),
+        show_columns_total=bool(form_data.get("colTotals")),
+        apply_metrics_on_rows=form_data.get("metricsLayout") == "ROWS",
     )
 
-    # The pandas `pivot_table` method either brings both row/column
-    # totals, or none at all. We pass `margin=True` to get both, and
-    # remove any dimension that was not requests.
-    if columns and not form_data.get("rowTotals"):
-        df.drop(df.columns[len(df.columns) - 1], axis=1, inplace=True)
-    if groupby and not form_data.get("colTotals"):
-        df = df[:-1]
-
-    # Compute fractions, if needed. If `colTotals` or `rowTotals` are
-    # present we need to adjust for including them in the sum
-    if aggregate_function.endswith(" as Fraction of Total"):
-        total = df.sum().sum()
-        df = df.astype(total.dtypes) / total
-        if form_data.get("colTotals"):
-            df *= 2
-        if form_data.get("rowTotals"):
-            df *= 2
-    elif aggregate_function.endswith(" as Fraction of Columns"):
-        total = df.sum(axis=0)
-        df = df.astype(total.dtypes).div(total, axis=1)
-        if form_data.get("colTotals"):
-            df *= 2
-    elif aggregate_function.endswith(" as Fraction of Rows"):
-        total = df.sum(axis=1)
-        df = df.astype(total.dtypes).div(total, axis=0)
-        if form_data.get("rowTotals"):
-            df *= 2
-
-    # Display metrics side by side with each column
-    if form_data.get("combineMetric"):
-        df = df.stack(0).unstack().reindex(level=-1, columns=metrics)
-
-    # flatten column names
-    df.columns = [
-        " ".join(str(name) for name in column) if isinstance(column, tuple) else column
-        for column in df.columns
-    ]
 
-    return df
+def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
+    """
+    Pivot table (v1), implemented by mapping legacy form data onto ``pivot_df``.
+    """
+    if form_data.get("granularity") == "all" and DTTM_ALIAS in df:
+        del df[DTTM_ALIAS]
+
+    # v1 pandas aggfunc names => v2 names (unmapped values fall back to "Sum")
+    func_map = {
+        "sum": "Sum",
+        "mean": "Average",
+        "min": "Minimum",
+        "max": "Maximum",
+        "std": "Sample Standard Deviation",
+        "var": "Sample Variance",
+    }
+
+    return pivot_df(
+        df,
+        rows=form_data.get("groupby") or [],
+        columns=form_data.get("columns") or [],
+        metrics=[get_metric_name(m) for m in form_data["metrics"]],
+        aggfunc=func_map.get(form_data.get("pandas_aggfunc", "sum"), "Sum"),
+        transpose_pivot=bool(form_data.get("transpose_pivot")),
+        combine_metrics=bool(form_data.get("combine_metric")),
+        show_rows_total=bool(form_data.get("pivot_margins")),
+        show_columns_total=bool(form_data.get("pivot_margins")),
+        apply_metrics_on_rows=False,
+    )
 
 
 post_processors = {
@@ -203,6 +282,14 @@ def apply_post_process(
         df = pd.read_csv(StringIO(query["data"]))
         processed_df = post_processor(df, form_data)
 
+        # flatten column names
+        processed_df.columns = [
+            " ".join(str(name) for name in column).strip()
+            if isinstance(column, tuple)
+            else column
+            for column in processed_df.columns
+        ]
+
         buf = StringIO()
         processed_df.to_csv(buf)
         buf.seek(0)
diff --git a/tests/unit_tests/charts/test_post_processing.py b/tests/unit_tests/charts/test_post_processing.py
index dc2f9a1..9463577 100644
--- a/tests/unit_tests/charts/test_post_processing.py
+++ b/tests/unit_tests/charts/test_post_processing.py
@@ -18,7 +18,9 @@
 import copy
 from typing import Any, Dict
 
-from superset.charts.post_processing import apply_post_process
+import pandas as pd
+
+from superset.charts.post_processing import apply_post_process, pivot_df
 from superset.utils.core import GenericDataType, QueryStatus
 
 RESULT: Dict[str, Any] = {
@@ -149,7 +151,8 @@ LIMIT 50000;
                     "Births PA",
                     "Births TX",
                     "Births other",
-                    "Births All",
+                    "Births Subtotal",
+                    "Total (Sum)",
                 ],
                 "coltypes": [
                     GenericDataType.NUMERIC,
@@ -164,11 +167,12 @@ LIMIT 50000;
                     GenericDataType.NUMERIC,
                     GenericDataType.NUMERIC,
                     GenericDataType.NUMERIC,
+                    GenericDataType.NUMERIC,
                 ],
-                "data": """gender,Births CA,Births FL,Births IL,Births MA,Births MI,Births NJ,Births NY,Births OH,Births PA,Births TX,Births other,Births All
-boy,5430796,1968060,2357411,1285126,1938321,1486126,3543961,2376385,2390275,3311985,22044909,48133355
-girl,3567754,1312593,1614427,842146,1326229,992702,2280733,1622814,1615383,2313186,15058341,32546308
-All,8998550,3280653,3971838,2127272,3264550,2478828,5824694,3999199,4005658,5625171,37103250,80679663
+                "data": """,Births CA,Births FL,Births IL,Births MA,Births MI,Births NJ,Births NY,Births OH,Births PA,Births TX,Births other,Births Subtotal,Total (Sum)
+boy,5430796,1968060,2357411,1285126,1938321,1486126,3543961,2376385,2390275,3311985,22044909,48133355,48133355
+girl,3567754,1312593,1614427,842146,1326229,992702,2280733,1622814,1615383,2313186,15058341,32546308,32546308
+Total (Sum),8998550,3280653,3971838,2127272,3264550,2478828,5824694,3999199,4005658,5625171,37103250,80679663,80679663
 """,
                 "applied_filters": [],
                 "rejected_filters": [],
@@ -199,7 +203,7 @@ def test_pivot_table_v2():
                 "optionName": "metric_11",
             }
         ],
-        "metricsLayout": "ROWS",
+        "metricsLayout": "COLUMNS",
         "rowOrder": "key_a_to_z",
         "rowTotals": True,
         "row_limit": 50000,
@@ -237,28 +241,746 @@ LIMIT 50000;
                 "status": QueryStatus.SUCCESS,
                 "stacktrace": None,
                 "rowcount": 12,
-                "colnames": ["All Births", "boy Births", "girl Births"],
+                "colnames": [
+                    "boy Births",
+                    "boy Subtotal",
+                    "girl Births",
+                    "girl Subtotal",
+                    "Total (Sum as Fraction of Rows)",
+                ],
                 "coltypes": [
                     GenericDataType.NUMERIC,
                     GenericDataType.NUMERIC,
                     GenericDataType.NUMERIC,
+                    GenericDataType.NUMERIC,
+                    GenericDataType.NUMERIC,
                 ],
-                "data": """state,All Births,boy Births,girl Births
-All,1.0,0.5965983645717509,0.40340163542824914
-CA,1.0,0.6035190113962805,0.3964809886037195
-FL,1.0,0.5998988615985903,0.4001011384014097
-IL,1.0,0.5935315085862012,0.40646849141379887
-MA,1.0,0.6041192663655611,0.3958807336344389
-MI,1.0,0.5937482960898133,0.4062517039101867
-NJ,1.0,0.5995276800165239,0.40047231998347604
-NY,1.0,0.6084372844307357,0.39156271556926425
-OH,1.0,0.5942152416021308,0.40578475839786915
-PA,1.0,0.596724682935987,0.40327531706401293
-TX,1.0,0.5887794344385264,0.41122056556147357
-other,1.0,0.5941503507105172,0.40584964928948275
+                "data": """,boy Births,boy Subtotal,girl Births,girl Subtotal,Total (Sum as Fraction of Rows)
+CA,0.6035190113962805,0.6035190113962805,0.3964809886037195,0.3964809886037195,1.0
+FL,0.5998988615985903,0.5998988615985903,0.4001011384014097,0.4001011384014097,1.0
+IL,0.5935315085862012,0.5935315085862012,0.40646849141379887,0.40646849141379887,1.0
+MA,0.6041192663655611,0.6041192663655611,0.3958807336344389,0.3958807336344389,1.0
+MI,0.5937482960898133,0.5937482960898133,0.4062517039101867,0.4062517039101867,1.0
+NJ,0.5995276800165239,0.5995276800165239,0.40047231998347604,0.40047231998347604,1.0
+NY,0.6084372844307357,0.6084372844307357,0.39156271556926425,0.39156271556926425,1.0
+OH,0.5942152416021308,0.5942152416021308,0.40578475839786915,0.40578475839786915,1.0
+PA,0.596724682935987,0.596724682935987,0.40327531706401293,0.40327531706401293,1.0
+TX,0.5887794344385264,0.5887794344385264,0.41122056556147357,0.41122056556147357,1.0
+other,0.5941503507105172,0.5941503507105172,0.40584964928948275,0.40584964928948275,1.0
+Total (Sum as Fraction of Rows),6.576651618170867,6.576651618170867,4.423348381829133,4.423348381829133,11.0
 """,
                 "applied_filters": [],
                 "rejected_filters": [],
             }
         ],
     }
+
+
+def test_pivot_df_no_cols_no_rows_single_metric():
+    """
+    Pivot table when no cols/rows and 1 metric are selected.
+    """
+    # when no cols/rows are selected there are no groupbys in the query,
+    # and the data has only the metric(s)
+    df = pd.DataFrame.from_dict({"SUM(num)": {0: 80679663}})
+    assert (
+        df.to_markdown()
+        == """
+|    |    SUM(num) |
+|---:|------------:|
+|  0 | 8.06797e+07 |
+    """.strip()
+    )
+
+    pivoted = pivot_df(
+        df,
+        rows=[],
+        columns=[],
+        metrics=["SUM(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=False,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| metric      |    SUM(num) |
+|:------------|------------:|
+| Total (Sum) | 8.06797e+07 |
+    """.strip()
+    )
+
+    # transpose_pivot and combine_metrics do nothing in this case
+    pivoted = pivot_df(
+        df,
+        rows=[],
+        columns=[],
+        metrics=["SUM(num)"],
+        aggfunc="Sum",
+        transpose_pivot=True,
+        combine_metrics=True,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| metric      |    SUM(num) |
+|:------------|------------:|
+| Total (Sum) | 8.06797e+07 |
+    """.strip()
+    )
+
+    # apply_metrics_on_rows will pivot the table, moving the metrics
+    # to rows
+    pivoted = pivot_df(
+        df,
+        rows=[],
+        columns=[],
+        metrics=["SUM(num)"],
+        aggfunc="Sum",
+        transpose_pivot=True,
+        combine_metrics=True,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=True,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|          |   Total (Sum) |
+|:---------|--------------:|
+| SUM(num) |   8.06797e+07 |
+    """.strip()
+    )
+
+    # showing totals
+    pivoted = pivot_df(
+        df,
+        rows=[],
+        columns=[],
+        metrics=["SUM(num)"],
+        aggfunc="Sum",
+        transpose_pivot=True,
+        combine_metrics=True,
+        show_rows_total=True,
+        show_columns_total=True,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| metric      |   ('SUM(num)',) |   ('Total (Sum)',) |
+|:------------|----------------:|-------------------:|
+| Total (Sum) |     8.06797e+07 |        8.06797e+07 |
+    """.strip()
+    )
+
+
+def test_pivot_df_no_cols_no_rows_two_metrics():
+    """
+    Pivot table when no cols/rows and 2 metrics are selected.
+    """
+    # when no cols/rows are selected there are no groupbys in the query,
+    # and the data has only the metrics
+    df = pd.DataFrame.from_dict({"SUM(num)": {0: 80679663}, "MAX(num)": {0: 37296}})
+    assert (
+        df.to_markdown()
+        == """
+|    |    SUM(num) |   MAX(num) |
+|---:|------------:|-----------:|
+|  0 | 8.06797e+07 |      37296 |
+    """.strip()
+    )
+
+    pivoted = pivot_df(
+        df,
+        rows=[],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=False,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| metric      |    SUM(num) |   MAX(num) |
+|:------------|------------:|-----------:|
+| Total (Sum) | 8.06797e+07 |      37296 |
+    """.strip()
+    )
+
+    # transpose_pivot and combine_metrics do nothing in this case
+    pivoted = pivot_df(
+        df,
+        rows=[],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=True,
+        combine_metrics=True,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| metric      |    SUM(num) |   MAX(num) |
+|:------------|------------:|-----------:|
+| Total (Sum) | 8.06797e+07 |      37296 |
+    """.strip()
+    )
+
+    # apply_metrics_on_rows will pivot the table, moving the metrics
+    # to rows
+    pivoted = pivot_df(
+        df,
+        rows=[],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=True,
+        combine_metrics=True,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=True,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|          |     Total (Sum) |
+|:---------|----------------:|
+| SUM(num) |     8.06797e+07 |
+| MAX(num) | 37296           |
+    """.strip()
+    )
+
+    # when showing totals we only add a column, since adding a row
+    # would be redundant
+    pivoted = pivot_df(
+        df,
+        rows=[],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=True,
+        combine_metrics=True,
+        show_rows_total=True,
+        show_columns_total=True,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| metric      |   ('SUM(num)',) |   ('MAX(num)',) |   ('Total (Sum)',) |
+|:------------|----------------:|----------------:|-------------------:|
+| Total (Sum) |     8.06797e+07 |           37296 |         8.0717e+07 |
+    """.strip()
+    )
+
+
+def test_pivot_df_single_row_two_metrics():
+    """
+    Pivot table when a single row and 2 metrics are selected.
+    """
+    df = pd.DataFrame.from_dict(
+        {
+            "gender": {0: "girl", 1: "boy"},
+            "SUM(num)": {0: 118065, 1: 47123},
+            "MAX(num)": {0: 2588, 1: 1280},
+        }
+    )
+    assert (
+        df.to_markdown()
+        == """
+|    | gender   |   SUM(num) |   MAX(num) |
+|---:|:---------|-----------:|-----------:|
+|  0 | girl     |     118065 |       2588 |
+|  1 | boy      |      47123 |       1280 |
+    """.strip()
+    )
+
+    pivoted = pivot_df(
+        df,
+        rows=["gender"],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=False,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| gender   |   SUM(num) |   MAX(num) |
+|:---------|-----------:|-----------:|
+| boy      |      47123 |       1280 |
+| girl     |     118065 |       2588 |
+    """.strip()
+    )
+
+    # transpose_pivot
+    pivoted = pivot_df(
+        df,
+        rows=["gender"],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=True,
+        combine_metrics=False,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| metric      |   ('SUM(num)', 'boy') |   ('SUM(num)', 'girl') |   ('MAX(num)', 'boy') |   ('MAX(num)', 'girl') |
+|:------------|----------------------:|-----------------------:|----------------------:|-----------------------:|
+| Total (Sum) |                 47123 |                 118065 |                  1280 |                   2588 |
+    """.strip()
+    )
+
+    # combine_metrics does nothing in this case
+    pivoted = pivot_df(
+        df,
+        rows=["gender"],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=True,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| gender   |   SUM(num) |   MAX(num) |
+|:---------|-----------:|-----------:|
+| boy      |      47123 |       1280 |
+| girl     |     118065 |       2588 |
+    """.strip()
+    )
+
+    # show totals
+    pivoted = pivot_df(
+        df,
+        rows=["gender"],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=False,
+        show_rows_total=True,
+        show_columns_total=True,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                  |   ('SUM(num)',) |   ('MAX(num)',) |   ('Total (Sum)',) |
+|:-----------------|----------------:|----------------:|-------------------:|
+| ('boy',)         |           47123 |            1280 |              48403 |
+| ('girl',)        |          118065 |            2588 |             120653 |
+| ('Total (Sum)',) |          165188 |            3868 |             169056 |
+    """.strip()
+    )
+
+    # apply_metrics_on_rows
+    pivoted = pivot_df(
+        df,
+        rows=["gender"],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=False,
+        show_rows_total=True,
+        show_columns_total=True,
+        apply_metrics_on_rows=True,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                          |   Total (Sum) |
+|:-------------------------|--------------:|
+| ('SUM(num)', 'boy')      |         47123 |
+| ('SUM(num)', 'girl')     |        118065 |
+| ('SUM(num)', 'Subtotal') |        165188 |
+| ('MAX(num)', 'boy')      |          1280 |
+| ('MAX(num)', 'girl')     |          2588 |
+| ('MAX(num)', 'Subtotal') |          3868 |
+| ('Total (Sum)', '')      |        169056 |
+    """.strip()
+    )
+
+    # apply_metrics_on_rows with combine_metrics
+    pivoted = pivot_df(
+        df,
+        rows=["gender"],
+        columns=[],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=True,
+        show_rows_total=True,
+        show_columns_total=True,
+        apply_metrics_on_rows=True,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                      |   Total (Sum) |
+|:---------------------|--------------:|
+| ('boy', 'SUM(num)')  |         47123 |
+| ('boy', 'MAX(num)')  |          1280 |
+| ('boy', 'Subtotal')  |         48403 |
+| ('girl', 'SUM(num)') |        118065 |
+| ('girl', 'MAX(num)') |          2588 |
+| ('girl', 'Subtotal') |        120653 |
+| ('Total (Sum)', '')  |        169056 |
+    """.strip()
+    )
+
+
+def test_pivot_df_complex():
+    """
+    Pivot table when a column, rows and 2 metrics are selected.
+    """
+    df = pd.DataFrame.from_dict(
+        {
+            "state": {
+                0: "CA",
+                1: "CA",
+                2: "CA",
+                3: "FL",
+                4: "CA",
+                5: "CA",
+                6: "FL",
+                7: "FL",
+                8: "FL",
+                9: "CA",
+                10: "FL",
+                11: "FL",
+            },
+            "gender": {
+                0: "girl",
+                1: "boy",
+                2: "girl",
+                3: "girl",
+                4: "girl",
+                5: "girl",
+                6: "boy",
+                7: "girl",
+                8: "girl",
+                9: "boy",
+                10: "boy",
+                11: "girl",
+            },
+            "name": {
+                0: "Amy",
+                1: "Edward",
+                2: "Sophia",
+                3: "Amy",
+                4: "Cindy",
+                5: "Dawn",
+                6: "Edward",
+                7: "Sophia",
+                8: "Dawn",
+                9: "Tony",
+                10: "Tony",
+                11: "Cindy",
+            },
+            "SUM(num)": {
+                0: 45426,
+                1: 31290,
+                2: 18859,
+                3: 14740,
+                4: 14149,
+                5: 11403,
+                6: 9395,
+                7: 7181,
+                8: 5089,
+                9: 3765,
+                10: 2673,
+                11: 1218,
+            },
+            "MAX(num)": {
+                0: 2227,
+                1: 1280,
+                2: 2588,
+                3: 854,
+                4: 842,
+                5: 1157,
+                6: 389,
+                7: 1187,
+                8: 461,
+                9: 598,
+                10: 247,
+                11: 217,
+            },
+        }
+    )
+    assert (
+        df.to_markdown()
+        == """
+|    | state   | gender   | name   |   SUM(num) |   MAX(num) |
+|---:|:--------|:---------|:-------|-----------:|-----------:|
+|  0 | CA      | girl     | Amy    |      45426 |       2227 |
+|  1 | CA      | boy      | Edward |      31290 |       1280 |
+|  2 | CA      | girl     | Sophia |      18859 |       2588 |
+|  3 | FL      | girl     | Amy    |      14740 |        854 |
+|  4 | CA      | girl     | Cindy  |      14149 |        842 |
+|  5 | CA      | girl     | Dawn   |      11403 |       1157 |
+|  6 | FL      | boy      | Edward |       9395 |        389 |
+|  7 | FL      | girl     | Sophia |       7181 |       1187 |
+|  8 | FL      | girl     | Dawn   |       5089 |        461 |
+|  9 | CA      | boy      | Tony   |       3765 |        598 |
+| 10 | FL      | boy      | Tony   |       2673 |        247 |
+| 11 | FL      | girl     | Cindy  |       1218 |        217 |
+    """.strip()
+    )
+
+    pivoted = pivot_df(
+        df,
+        rows=["gender", "name"],
+        columns=["state"],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=False,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                    |   ('SUM(num)', 'CA') |   ('SUM(num)', 'FL') |   ('MAX(num)', 'CA') |   ('MAX(num)', 'FL') |
+|:-------------------|---------------------:|---------------------:|---------------------:|---------------------:|
+| ('boy', 'Edward')  |                31290 |                 9395 |                 1280 |                  389 |
+| ('boy', 'Tony')    |                 3765 |                 2673 |                  598 |                  247 |
+| ('girl', 'Amy')    |                45426 |                14740 |                 2227 |                  854 |
+| ('girl', 'Cindy')  |                14149 |                 1218 |                  842 |                  217 |
+| ('girl', 'Dawn')   |                11403 |                 5089 |                 1157 |                  461 |
+| ('girl', 'Sophia') |                18859 |                 7181 |                 2588 |                 1187 |
+    """.strip()
+    )
+
+    # transpose_pivot
+    pivoted = pivot_df(
+        df,
+        rows=["gender", "name"],
+        columns=["state"],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=True,
+        combine_metrics=False,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+| state   |   ('SUM(num)', 'boy', 'Edward') |   ('SUM(num)', 'boy', 'Tony') |   ('SUM(num)', 'girl', 'Amy') |   ('SUM(num)', 'girl', 'Cindy') |   ('SUM(num)', 'girl', 'Dawn') |   ('SUM(num)', 'girl', 'Sophia') |   ('MAX(num)', 'boy', 'Edward') |   ('MAX(num)', 'boy', 'Tony') |   ('MAX(num)', 'girl', 'Amy') |   ('MAX(num)', 'girl', 'Cindy') |   ('MAX(num)', 'girl', 'Dawn') |   ('MAX(num)', 'girl', 'Sophia') |
+|:--------|--------------------------------:|------------------------------:|------------------------------:|--------------------------------:|-------------------------------:|---------------------------------:|--------------------------------:|------------------------------:|------------------------------:|--------------------------------:|-------------------------------:|---------------------------------:|
+| CA      |                           31290 |                          3765 |                         45426 |                           14149 |                          11403 |                            18859 |                            1280 |                           598 |                          2227 |                             842 |                           1157 |                             2588 |
+| FL      |                            9395 |                          2673 |                         14740 |                            1218 |                           5089 |                             7181 |                             389 |                           247 |                           854 |                             217 |                            461 |                             1187 |
+    """.strip()
+    )
+
+    # combine_metrics
+    pivoted = pivot_df(
+        df,
+        rows=["gender", "name"],
+        columns=["state"],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=True,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                    |   ('CA', 'SUM(num)') |   ('CA', 'MAX(num)') |   ('FL', 'SUM(num)') |   ('FL', 'MAX(num)') |
+|:-------------------|---------------------:|---------------------:|---------------------:|---------------------:|
+| ('boy', 'Edward')  |                31290 |                 1280 |                 9395 |                  389 |
+| ('boy', 'Tony')    |                 3765 |                  598 |                 2673 |                  247 |
+| ('girl', 'Amy')    |                45426 |                 2227 |                14740 |                  854 |
+| ('girl', 'Cindy')  |                14149 |                  842 |                 1218 |                  217 |
+| ('girl', 'Dawn')   |                11403 |                 1157 |                 5089 |                  461 |
+| ('girl', 'Sophia') |                18859 |                 2588 |                 7181 |                 1187 |
+    """.strip()
+    )
+
+    # show totals
+    pivoted = pivot_df(
+        df,
+        rows=["gender", "name"],
+        columns=["state"],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=False,
+        show_rows_total=True,
+        show_columns_total=True,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                      |   ('SUM(num)', 'CA') |   ('SUM(num)', 'FL') |   ('SUM(num)', 'Subtotal') |   ('MAX(num)', 'CA') |   ('MAX(num)', 'FL') |   ('MAX(num)', 'Subtotal') |   ('Total (Sum)', '') |
+|:---------------------|---------------------:|---------------------:|---------------------------:|---------------------:|---------------------:|---------------------------:|----------------------:|
+| ('boy', 'Edward')    |                31290 |                 9395 |                      40685 |                 1280 |                  389 |                       1669 |                 42354 |
+| ('boy', 'Tony')      |                 3765 |                 2673 |                       6438 |                  598 |                  247 |                        845 |                  7283 |
+| ('boy', 'Subtotal')  |                35055 |                12068 |                      47123 |                 1878 |                  636 |                       2514 |                 49637 |
+| ('girl', 'Amy')      |                45426 |                14740 |                      60166 |                 2227 |                  854 |                       3081 |                 63247 |
+| ('girl', 'Cindy')    |                14149 |                 1218 |                      15367 |                  842 |                  217 |                       1059 |                 16426 |
+| ('girl', 'Dawn')     |                11403 |                 5089 |                      16492 |                 1157 |                  461 |                       1618 |                 18110 |
+| ('girl', 'Sophia')   |                18859 |                 7181 |                      26040 |                 2588 |                 1187 |                       3775 |                 29815 |
+| ('girl', 'Subtotal') |                89837 |                28228 |                     118065 |                 6814 |                 2719 |                       9533 |                127598 |
+| ('Total (Sum)', '')  |               124892 |                40296 |                     165188 |                 8692 |                 3355 |                      12047 |                177235 |
+    """.strip()
+    )
+
+    # apply_metrics_on_rows
+    pivoted = pivot_df(
+        df,
+        rows=["gender", "name"],
+        columns=["state"],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=False,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=True,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                                |    CA |    FL |
+|:-------------------------------|------:|------:|
+| ('SUM(num)', 'boy', 'Edward')  | 31290 |  9395 |
+| ('SUM(num)', 'boy', 'Tony')    |  3765 |  2673 |
+| ('SUM(num)', 'girl', 'Amy')    | 45426 | 14740 |
+| ('SUM(num)', 'girl', 'Cindy')  | 14149 |  1218 |
+| ('SUM(num)', 'girl', 'Dawn')   | 11403 |  5089 |
+| ('SUM(num)', 'girl', 'Sophia') | 18859 |  7181 |
+| ('MAX(num)', 'boy', 'Edward')  |  1280 |   389 |
+| ('MAX(num)', 'boy', 'Tony')    |   598 |   247 |
+| ('MAX(num)', 'girl', 'Amy')    |  2227 |   854 |
+| ('MAX(num)', 'girl', 'Cindy')  |   842 |   217 |
+| ('MAX(num)', 'girl', 'Dawn')   |  1157 |   461 |
+| ('MAX(num)', 'girl', 'Sophia') |  2588 |  1187 |
+    """.strip()
+    )
+
+    # apply_metrics_on_rows with combine_metrics
+    pivoted = pivot_df(
+        df,
+        rows=["gender", "name"],
+        columns=["state"],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=False,
+        combine_metrics=True,
+        show_rows_total=False,
+        show_columns_total=False,
+        apply_metrics_on_rows=True,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                                |    CA |    FL |
+|:-------------------------------|------:|------:|
+| ('boy', 'Edward', 'SUM(num)')  | 31290 |  9395 |
+| ('boy', 'Edward', 'MAX(num)')  |  1280 |   389 |
+| ('boy', 'Tony', 'SUM(num)')    |  3765 |  2673 |
+| ('boy', 'Tony', 'MAX(num)')    |   598 |   247 |
+| ('girl', 'Amy', 'SUM(num)')    | 45426 | 14740 |
+| ('girl', 'Amy', 'MAX(num)')    |  2227 |   854 |
+| ('girl', 'Cindy', 'SUM(num)')  | 14149 |  1218 |
+| ('girl', 'Cindy', 'MAX(num)')  |   842 |   217 |
+| ('girl', 'Dawn', 'SUM(num)')   | 11403 |  5089 |
+| ('girl', 'Dawn', 'MAX(num)')   |  1157 |   461 |
+| ('girl', 'Sophia', 'SUM(num)') | 18859 |  7181 |
+| ('girl', 'Sophia', 'MAX(num)') |  2588 |  1187 |
+    """.strip()
+    )
+
+    # everything
+    pivoted = pivot_df(
+        df,
+        rows=["gender", "name"],
+        columns=["state"],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum",
+        transpose_pivot=True,
+        combine_metrics=True,
+        show_rows_total=True,
+        show_columns_total=True,
+        apply_metrics_on_rows=True,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                     |   ('boy', 'Edward') |   ('boy', 'Tony') |   ('boy', 'Subtotal') |   ('girl', 'Amy') |   ('girl', 'Cindy') |   ('girl', 'Dawn') |   ('girl', 'Sophia') |   ('girl', 'Subtotal') |   ('Total (Sum)', '') |
+|:--------------------|--------------------:|------------------:|----------------------:|------------------:|--------------------:|-------------------:|---------------------:|-----------------------:|----------------------:|
+| ('CA', 'SUM(num)')  |               31290 |              3765 |                 35055 |             45426 |               14149 |              11403 |                18859 |                  89837 |                124892 |
+| ('CA', 'MAX(num)')  |                1280 |               598 |                  1878 |              2227 |                 842 |               1157 |                 2588 |                   6814 |                  8692 |
+| ('CA', 'Subtotal')  |               32570 |              4363 |                 36933 |             47653 |               14991 |              12560 |                21447 |                  96651 |                133584 |
+| ('FL', 'SUM(num)')  |                9395 |              2673 |                 12068 |             14740 |                1218 |               5089 |                 7181 |                  28228 |                 40296 |
+| ('FL', 'MAX(num)')  |                 389 |               247 |                   636 |               854 |                 217 |                461 |                 1187 |                   2719 |                  3355 |
+| ('FL', 'Subtotal')  |                9784 |              2920 |                 12704 |             15594 |                1435 |               5550 |                 8368 |                  30947 |                 43651 |
+| ('Total (Sum)', '') |               42354 |              7283 |                 49637 |             63247 |               16426 |              18110 |                29815 |                 127598 |                177235 |
+    """.strip()
+    )
+
+    # fraction
+    pivoted = pivot_df(
+        df,
+        rows=["gender", "name"],
+        columns=["state"],
+        metrics=["SUM(num)", "MAX(num)"],
+        aggfunc="Sum as Fraction of Columns",
+        transpose_pivot=False,
+        combine_metrics=False,
+        show_rows_total=False,
+        show_columns_total=True,
+        apply_metrics_on_rows=False,
+    )
+    assert (
+        pivoted.to_markdown()
+        == """
+|                                            |   ('SUM(num)', 'CA') |   ('SUM(num)', 'FL') |   ('MAX(num)', 'CA') |   ('MAX(num)', 'FL') |
+|:-------------------------------------------|---------------------:|---------------------:|---------------------:|---------------------:|
+| ('boy', 'Edward')                          |            0.250536  |            0.23315   |            0.147262  |            0.115946  |
+| ('boy', 'Tony')                            |            0.030146  |            0.0663341 |            0.0687989 |            0.0736215 |
+| ('boy', 'Subtotal')                        |            0.280683  |            0.299484  |            0.216061  |            0.189568  |
+| ('girl', 'Amy')                            |            0.363722  |            0.365793  |            0.256213  |            0.254545  |
+| ('girl', 'Cindy')                          |            0.11329   |            0.0302263 |            0.0968707 |            0.0646796 |
+| ('girl', 'Dawn')                           |            0.0913029 |            0.12629   |            0.133111  |            0.137407  |
+| ('girl', 'Sophia')                         |            0.151002  |            0.178206  |            0.297745  |            0.3538    |
+| ('girl', 'Subtotal')                       |            0.719317  |            0.700516  |            0.783939  |            0.810432  |
+| ('Total (Sum as Fraction of Columns)', '') |            1         |            1         |            1         |            1         |
+    """.strip()
+    )