You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@superset.apache.org by be...@apache.org on 2021/08/17 21:42:31 UTC
[superset] branch master updated: fix: improve pivot post-processing (#16289)
This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new ac8e54d fix: improve pivot post-processing (#16289)
ac8e54d is described below
commit ac8e54d9094aaba05a87167da611b169464064da
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Tue Aug 17 14:41:22 2021 -0700
fix: improve pivot post-processing (#16289)
* fix: improve pivot post-processing
* Add tests
* Trim whitespace from flattened column names
---
superset/charts/post_processing.py | 281 ++++++---
tests/unit_tests/charts/test_post_processing.py | 764 +++++++++++++++++++++++-
2 files changed, 927 insertions(+), 118 deletions(-)
diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py
index b67d870..4919907 100644
--- a/superset/charts/post_processing.py
+++ b/superset/charts/post_processing.py
@@ -27,60 +27,151 @@ for these chart types.
"""
from io import StringIO
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple
import pandas as pd
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name
-def sql_like_sum(series: pd.Series) -> pd.Series:
+def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
"""
- A SUM aggregation function that mimics the behavior from SQL.
- """
- return series.sum(min_count=1)
-
+ Sort columns when combining metrics.
-def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
+ MultiIndex labels have the metric name as the last element in the
+ tuple. We want to sort these according to the list of passed metrics.
"""
- Pivot table.
- """
- if form_data.get("granularity") == "all" and DTTM_ALIAS in df:
- del df[DTTM_ALIAS]
-
- metrics = [get_metric_name(m) for m in form_data["metrics"]]
- aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
- for metric in metrics:
- aggfunc = form_data.get("pandas_aggfunc") or "sum"
- if pd.api.types.is_numeric_dtype(df[metric]):
- if aggfunc == "sum":
- aggfunc = sql_like_sum
- elif aggfunc not in {"min", "max"}:
- aggfunc = "max"
- aggfuncs[metric] = aggfunc
-
- groupby = form_data.get("groupby") or []
- columns = form_data.get("columns") or []
- if form_data.get("transpose_pivot"):
- groupby, columns = columns, groupby
-
- df = df.pivot_table(
- index=groupby,
- columns=columns,
- values=metrics,
- aggfunc=aggfuncs,
- margins=form_data.get("pivot_margins"),
- )
-
- # Display metrics side by side with each column
- if form_data.get("combine_metric"):
- df = df.stack(0).unstack().reindex(level=-1, columns=metrics)
-
- # flatten column names
- df.columns = [
- " ".join(str(name) for name in column) if isinstance(column, tuple) else column
- for column in df.columns
- ]
+ parts: List[Any] = list(label)
+ metric = parts[-1]
+ parts[-1] = metrics.index(metric)
+ return tuple(parts)
+
+
+def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches
+ df: pd.DataFrame,
+ rows: List[str],
+ columns: List[str],
+ metrics: List[str],
+ aggfunc: str = "Sum",
+ transpose_pivot: bool = False,
+ combine_metrics: bool = False,
+ show_rows_total: bool = False,
+ show_columns_total: bool = False,
+ apply_metrics_on_rows: bool = False,
+) -> pd.DataFrame:
+ metric_name = f"Total ({aggfunc})"
+
+ if transpose_pivot:
+ rows, columns = columns, rows
+
+ # to apply the metrics on the rows we pivot the dataframe, apply the
+ # metrics to the columns, and pivot the dataframe back before
+ # returning it
+ if apply_metrics_on_rows:
+ rows, columns = columns, rows
+ axis = {"columns": 0, "rows": 1}
+ else:
+ axis = {"columns": 1, "rows": 0}
+
+ # pivot data; we'll compute totals and subtotals later
+ if rows or columns:
+ df = df.pivot_table(
+ index=rows,
+ columns=columns,
+ values=metrics,
+ aggfunc=pivot_v2_aggfunc_map[aggfunc],
+ margins=False,
+ )
+ else:
+        # if there are neither rows nor columns we have a single value; update
+ # the index with the metric name so it shows up in the table
+ df.index = pd.Index([*df.index[:-1], metric_name], name="metric")
+
+ # if no rows were passed the metrics will be in the rows, so we
+ # need to move them back to columns
+ if columns and not rows:
+ df = df.stack().to_frame().T
+ df = df[metrics]
+ df.index = pd.Index([*df.index[:-1], metric_name], name="metric")
+
+ # combining metrics changes the column hierarchy, moving the metric
+ # from the top to the bottom, eg:
+ #
+ # ('SUM(col)', 'age', 'name') => ('age', 'name', 'SUM(col)')
+ if combine_metrics and isinstance(df.columns, pd.MultiIndex):
+ # move metrics to the lowest level
+ new_order = [*range(1, df.columns.nlevels), 0]
+ df = df.reorder_levels(new_order, axis=1)
+
+ # sort columns, combining metrics for each group
+ decorated_columns = [(col, i) for i, col in enumerate(df.columns)]
+ grouped_columns = sorted(
+ decorated_columns, key=lambda t: get_column_key(t[0], metrics)
+ )
+ indexes = [i for col, i in grouped_columns]
+ df = df[df.columns[indexes]]
+ elif rows:
+ # if metrics were not combined we sort the dataframe by the list
+ # of metrics defined by the user
+ df = df[metrics]
+
+ # compute fractions, if needed
+ if aggfunc.endswith(" as Fraction of Total"):
+ total = df.sum().sum()
+ df = df.astype(total.dtypes) / total
+ elif aggfunc.endswith(" as Fraction of Columns"):
+ total = df.sum(axis=axis["rows"])
+ df = df.astype(total.dtypes).div(total, axis=axis["columns"])
+ elif aggfunc.endswith(" as Fraction of Rows"):
+ total = df.sum(axis=axis["columns"])
+ df = df.astype(total.dtypes).div(total, axis=axis["rows"])
+
+ if show_rows_total:
+ # convert to a MultiIndex to simplify logic
+ if not isinstance(df.columns, pd.MultiIndex):
+ df.columns = pd.MultiIndex.from_tuples([(str(i),) for i in df.columns])
+
+ # add subtotal for each group and overall total; we start from the
+ # overall group, and iterate deeper into subgroups
+ groups = df.columns
+ for level in range(df.columns.nlevels):
+ subgroups = {group[:level] for group in groups}
+ for subgroup in subgroups:
+ slice_ = df.columns.get_loc(subgroup)
+ subtotal = pivot_v2_aggfunc_map[aggfunc](df.iloc[:, slice_], axis=1)
+ depth = df.columns.nlevels - len(subgroup) - 1
+ total = metric_name if level == 0 else "Subtotal"
+ subtotal_name = tuple([*subgroup, total, *([""] * depth)])
+ # insert column after subgroup
+ df.insert(int(slice_.stop), subtotal_name, subtotal)
+
+ if rows and show_columns_total:
+ # convert to a MultiIndex to simplify logic
+ if not isinstance(df.index, pd.MultiIndex):
+ df.index = pd.MultiIndex.from_tuples([(str(i),) for i in df.index])
+
+ # add subtotal for each group and overall total; we start from the
+ # overall group, and iterate deeper into subgroups
+ groups = df.index
+ for level in range(df.index.nlevels):
+ subgroups = {group[:level] for group in groups}
+ for subgroup in subgroups:
+ slice_ = df.index.get_loc(subgroup)
+ subtotal = pivot_v2_aggfunc_map[aggfunc](
+ df.iloc[slice_, :].apply(pd.to_numeric), axis=0
+ )
+ depth = df.index.nlevels - len(subgroup) - 1
+ total = metric_name if level == 0 else "Subtotal"
+ subtotal.name = tuple([*subgroup, total, *([""] * depth)])
+ # insert row after subgroup
+ df = pd.concat(
+ [df[: slice_.stop], subtotal.to_frame().T, df[slice_.stop :]]
+ )
+
+ # if we want to apply the metrics on the rows we need to pivot the
+ # dataframe back
+ if apply_metrics_on_rows:
+ df = df.T
return df
@@ -125,61 +216,49 @@ def pivot_table_v2( # pylint: disable=too-many-branches
if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS]
- # TODO (betodealmeida): implement metricsLayout
- metrics = [get_metric_name(m) for m in form_data["metrics"]]
- aggregate_function = form_data.get("aggregateFunction", "Sum")
- groupby = form_data.get("groupbyRows") or []
- columns = form_data.get("groupbyColumns") or []
- if form_data.get("transposePivot"):
- groupby, columns = columns, groupby
-
- df = df.pivot_table(
- index=groupby,
- columns=columns,
- values=metrics,
- aggfunc=pivot_v2_aggfunc_map[aggregate_function],
- margins=True,
+ return pivot_df(
+ df,
+ rows=form_data.get("groupbyRows") or [],
+ columns=form_data.get("groupbyColumns") or [],
+ metrics=[get_metric_name(m) for m in form_data["metrics"]],
+ aggfunc=form_data.get("aggregateFunction", "Sum"),
+ transpose_pivot=bool(form_data.get("transposePivot")),
+ combine_metrics=bool(form_data.get("combineMetric")),
+ show_rows_total=bool(form_data.get("rowTotals")),
+ show_columns_total=bool(form_data.get("colTotals")),
+ apply_metrics_on_rows=form_data.get("metricsLayout") == "ROWS",
)
- # The pandas `pivot_table` method either brings both row/column
- # totals, or none at all. We pass `margin=True` to get both, and
- # remove any dimension that was not requests.
- if columns and not form_data.get("rowTotals"):
- df.drop(df.columns[len(df.columns) - 1], axis=1, inplace=True)
- if groupby and not form_data.get("colTotals"):
- df = df[:-1]
-
- # Compute fractions, if needed. If `colTotals` or `rowTotals` are
- # present we need to adjust for including them in the sum
- if aggregate_function.endswith(" as Fraction of Total"):
- total = df.sum().sum()
- df = df.astype(total.dtypes) / total
- if form_data.get("colTotals"):
- df *= 2
- if form_data.get("rowTotals"):
- df *= 2
- elif aggregate_function.endswith(" as Fraction of Columns"):
- total = df.sum(axis=0)
- df = df.astype(total.dtypes).div(total, axis=1)
- if form_data.get("colTotals"):
- df *= 2
- elif aggregate_function.endswith(" as Fraction of Rows"):
- total = df.sum(axis=1)
- df = df.astype(total.dtypes).div(total, axis=0)
- if form_data.get("rowTotals"):
- df *= 2
-
- # Display metrics side by side with each column
- if form_data.get("combineMetric"):
- df = df.stack(0).unstack().reindex(level=-1, columns=metrics)
-
- # flatten column names
- df.columns = [
- " ".join(str(name) for name in column) if isinstance(column, tuple) else column
- for column in df.columns
- ]
- return df
+def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
+ """
+ Pivot table (v1).
+ """
+ if form_data.get("granularity") == "all" and DTTM_ALIAS in df:
+ del df[DTTM_ALIAS]
+
+ # v1 func names => v2 func names
+ func_map = {
+ "sum": "Sum",
+ "mean": "Average",
+ "min": "Minimum",
+ "max": "Maximum",
+ "std": "Sample Standard Deviation",
+ "var": "Sample Variance",
+ }
+
+ return pivot_df(
+ df,
+ rows=form_data.get("groupby") or [],
+ columns=form_data.get("columns") or [],
+ metrics=[get_metric_name(m) for m in form_data["metrics"]],
+ aggfunc=func_map.get(form_data.get("pandas_aggfunc", "sum"), "Sum"),
+ transpose_pivot=bool(form_data.get("transpose_pivot")),
+ combine_metrics=bool(form_data.get("combine_metric")),
+ show_rows_total=bool(form_data.get("pivot_margins")),
+ show_columns_total=bool(form_data.get("pivot_margins")),
+ apply_metrics_on_rows=False,
+ )
post_processors = {
@@ -203,6 +282,14 @@ def apply_post_process(
df = pd.read_csv(StringIO(query["data"]))
processed_df = post_processor(df, form_data)
+ # flatten column names
+ processed_df.columns = [
+ " ".join(str(name) for name in column).strip()
+ if isinstance(column, tuple)
+ else column
+ for column in processed_df.columns
+ ]
+
buf = StringIO()
processed_df.to_csv(buf)
buf.seek(0)
diff --git a/tests/unit_tests/charts/test_post_processing.py b/tests/unit_tests/charts/test_post_processing.py
index dc2f9a1..9463577 100644
--- a/tests/unit_tests/charts/test_post_processing.py
+++ b/tests/unit_tests/charts/test_post_processing.py
@@ -18,7 +18,9 @@
import copy
from typing import Any, Dict
-from superset.charts.post_processing import apply_post_process
+import pandas as pd
+
+from superset.charts.post_processing import apply_post_process, pivot_df
from superset.utils.core import GenericDataType, QueryStatus
RESULT: Dict[str, Any] = {
@@ -149,7 +151,8 @@ LIMIT 50000;
"Births PA",
"Births TX",
"Births other",
- "Births All",
+ "Births Subtotal",
+ "Total (Sum)",
],
"coltypes": [
GenericDataType.NUMERIC,
@@ -164,11 +167,12 @@ LIMIT 50000;
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
+ GenericDataType.NUMERIC,
],
- "data": """gender,Births CA,Births FL,Births IL,Births MA,Births MI,Births NJ,Births NY,Births OH,Births PA,Births TX,Births other,Births All
-boy,5430796,1968060,2357411,1285126,1938321,1486126,3543961,2376385,2390275,3311985,22044909,48133355
-girl,3567754,1312593,1614427,842146,1326229,992702,2280733,1622814,1615383,2313186,15058341,32546308
-All,8998550,3280653,3971838,2127272,3264550,2478828,5824694,3999199,4005658,5625171,37103250,80679663
+ "data": """,Births CA,Births FL,Births IL,Births MA,Births MI,Births NJ,Births NY,Births OH,Births PA,Births TX,Births other,Births Subtotal,Total (Sum)
+boy,5430796,1968060,2357411,1285126,1938321,1486126,3543961,2376385,2390275,3311985,22044909,48133355,48133355
+girl,3567754,1312593,1614427,842146,1326229,992702,2280733,1622814,1615383,2313186,15058341,32546308,32546308
+Total (Sum),8998550,3280653,3971838,2127272,3264550,2478828,5824694,3999199,4005658,5625171,37103250,80679663,80679663
""",
"applied_filters": [],
"rejected_filters": [],
@@ -199,7 +203,7 @@ def test_pivot_table_v2():
"optionName": "metric_11",
}
],
- "metricsLayout": "ROWS",
+ "metricsLayout": "COLUMNS",
"rowOrder": "key_a_to_z",
"rowTotals": True,
"row_limit": 50000,
@@ -237,28 +241,746 @@ LIMIT 50000;
"status": QueryStatus.SUCCESS,
"stacktrace": None,
"rowcount": 12,
- "colnames": ["All Births", "boy Births", "girl Births"],
+ "colnames": [
+ "boy Births",
+ "boy Subtotal",
+ "girl Births",
+ "girl Subtotal",
+ "Total (Sum as Fraction of Rows)",
+ ],
"coltypes": [
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
+ GenericDataType.NUMERIC,
+ GenericDataType.NUMERIC,
],
- "data": """state,All Births,boy Births,girl Births
-All,1.0,0.5965983645717509,0.40340163542824914
-CA,1.0,0.6035190113962805,0.3964809886037195
-FL,1.0,0.5998988615985903,0.4001011384014097
-IL,1.0,0.5935315085862012,0.40646849141379887
-MA,1.0,0.6041192663655611,0.3958807336344389
-MI,1.0,0.5937482960898133,0.4062517039101867
-NJ,1.0,0.5995276800165239,0.40047231998347604
-NY,1.0,0.6084372844307357,0.39156271556926425
-OH,1.0,0.5942152416021308,0.40578475839786915
-PA,1.0,0.596724682935987,0.40327531706401293
-TX,1.0,0.5887794344385264,0.41122056556147357
-other,1.0,0.5941503507105172,0.40584964928948275
+ "data": """,boy Births,boy Subtotal,girl Births,girl Subtotal,Total (Sum as Fraction of Rows)
+CA,0.6035190113962805,0.6035190113962805,0.3964809886037195,0.3964809886037195,1.0
+FL,0.5998988615985903,0.5998988615985903,0.4001011384014097,0.4001011384014097,1.0
+IL,0.5935315085862012,0.5935315085862012,0.40646849141379887,0.40646849141379887,1.0
+MA,0.6041192663655611,0.6041192663655611,0.3958807336344389,0.3958807336344389,1.0
+MI,0.5937482960898133,0.5937482960898133,0.4062517039101867,0.4062517039101867,1.0
+NJ,0.5995276800165239,0.5995276800165239,0.40047231998347604,0.40047231998347604,1.0
+NY,0.6084372844307357,0.6084372844307357,0.39156271556926425,0.39156271556926425,1.0
+OH,0.5942152416021308,0.5942152416021308,0.40578475839786915,0.40578475839786915,1.0
+PA,0.596724682935987,0.596724682935987,0.40327531706401293,0.40327531706401293,1.0
+TX,0.5887794344385264,0.5887794344385264,0.41122056556147357,0.41122056556147357,1.0
+other,0.5941503507105172,0.5941503507105172,0.40584964928948275,0.40584964928948275,1.0
+Total (Sum as Fraction of Rows),6.576651618170867,6.576651618170867,4.423348381829133,4.423348381829133,11.0
""",
"applied_filters": [],
"rejected_filters": [],
}
],
}
+
+
+def test_pivot_df_no_cols_no_rows_single_metric():
+ """
+ Pivot table when no cols/rows and 1 metric are selected.
+ """
+ # when no cols/rows are selected there are no groupbys in the query,
+ # and the data has only the metric(s)
+ df = pd.DataFrame.from_dict({"SUM(num)": {0: 80679663}})
+ assert (
+ df.to_markdown()
+ == """
+| | SUM(num) |
+|---:|------------:|
+| 0 | 8.06797e+07 |
+ """.strip()
+ )
+
+ pivoted = pivot_df(
+ df,
+ rows=[],
+ columns=[],
+ metrics=["SUM(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=False,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| metric | SUM(num) |
+|:------------|------------:|
+| Total (Sum) | 8.06797e+07 |
+ """.strip()
+ )
+
+    # transpose_pivot and combine_metrics do nothing in this case
+ pivoted = pivot_df(
+ df,
+ rows=[],
+ columns=[],
+ metrics=["SUM(num)"],
+ aggfunc="Sum",
+ transpose_pivot=True,
+ combine_metrics=True,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| metric | SUM(num) |
+|:------------|------------:|
+| Total (Sum) | 8.06797e+07 |
+ """.strip()
+ )
+
+ # apply_metrics_on_rows will pivot the table, moving the metrics
+ # to rows
+ pivoted = pivot_df(
+ df,
+ rows=[],
+ columns=[],
+ metrics=["SUM(num)"],
+ aggfunc="Sum",
+ transpose_pivot=True,
+ combine_metrics=True,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=True,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | Total (Sum) |
+|:---------|--------------:|
+| SUM(num) | 8.06797e+07 |
+ """.strip()
+ )
+
+ # showing totals
+ pivoted = pivot_df(
+ df,
+ rows=[],
+ columns=[],
+ metrics=["SUM(num)"],
+ aggfunc="Sum",
+ transpose_pivot=True,
+ combine_metrics=True,
+ show_rows_total=True,
+ show_columns_total=True,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| metric | ('SUM(num)',) | ('Total (Sum)',) |
+|:------------|----------------:|-------------------:|
+| Total (Sum) | 8.06797e+07 | 8.06797e+07 |
+ """.strip()
+ )
+
+
+def test_pivot_df_no_cols_no_rows_two_metrics():
+ """
+ Pivot table when no cols/rows and 2 metrics are selected.
+ """
+ # when no cols/rows are selected there are no groupbys in the query,
+ # and the data has only the metrics
+ df = pd.DataFrame.from_dict({"SUM(num)": {0: 80679663}, "MAX(num)": {0: 37296}})
+ assert (
+ df.to_markdown()
+ == """
+| | SUM(num) | MAX(num) |
+|---:|------------:|-----------:|
+| 0 | 8.06797e+07 | 37296 |
+ """.strip()
+ )
+
+ pivoted = pivot_df(
+ df,
+ rows=[],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=False,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| metric | SUM(num) | MAX(num) |
+|:------------|------------:|-----------:|
+| Total (Sum) | 8.06797e+07 | 37296 |
+ """.strip()
+ )
+
+    # transpose_pivot and combine_metrics do nothing in this case
+ pivoted = pivot_df(
+ df,
+ rows=[],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=True,
+ combine_metrics=True,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| metric | SUM(num) | MAX(num) |
+|:------------|------------:|-----------:|
+| Total (Sum) | 8.06797e+07 | 37296 |
+ """.strip()
+ )
+
+ # apply_metrics_on_rows will pivot the table, moving the metrics
+ # to rows
+ pivoted = pivot_df(
+ df,
+ rows=[],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=True,
+ combine_metrics=True,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=True,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | Total (Sum) |
+|:---------|----------------:|
+| SUM(num) | 8.06797e+07 |
+| MAX(num) | 37296 |
+ """.strip()
+ )
+
+ # when showing totals we only add a column, since adding a row
+ # would be redundant
+ pivoted = pivot_df(
+ df,
+ rows=[],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=True,
+ combine_metrics=True,
+ show_rows_total=True,
+ show_columns_total=True,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| metric | ('SUM(num)',) | ('MAX(num)',) | ('Total (Sum)',) |
+|:------------|----------------:|----------------:|-------------------:|
+| Total (Sum) | 8.06797e+07 | 37296 | 8.0717e+07 |
+ """.strip()
+ )
+
+
+def test_pivot_df_single_row_two_metrics():
+ """
+ Pivot table when a single column and 2 metrics are selected.
+ """
+ df = pd.DataFrame.from_dict(
+ {
+ "gender": {0: "girl", 1: "boy"},
+ "SUM(num)": {0: 118065, 1: 47123},
+ "MAX(num)": {0: 2588, 1: 1280},
+ }
+ )
+ assert (
+ df.to_markdown()
+ == """
+| | gender | SUM(num) | MAX(num) |
+|---:|:---------|-----------:|-----------:|
+| 0 | girl | 118065 | 2588 |
+| 1 | boy | 47123 | 1280 |
+ """.strip()
+ )
+
+ pivoted = pivot_df(
+ df,
+ rows=["gender"],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=False,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| gender | SUM(num) | MAX(num) |
+|:---------|-----------:|-----------:|
+| boy | 47123 | 1280 |
+| girl | 118065 | 2588 |
+ """.strip()
+ )
+
+ # transpose_pivot
+ pivoted = pivot_df(
+ df,
+ rows=["gender"],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=True,
+ combine_metrics=False,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| metric | ('SUM(num)', 'boy') | ('SUM(num)', 'girl') | ('MAX(num)', 'boy') | ('MAX(num)', 'girl') |
+|:------------|----------------------:|-----------------------:|----------------------:|-----------------------:|
+| Total (Sum) | 47123 | 118065 | 1280 | 2588 |
+ """.strip()
+ )
+
+ # combine_metrics does nothing in this case
+ pivoted = pivot_df(
+ df,
+ rows=["gender"],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=True,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| gender | SUM(num) | MAX(num) |
+|:---------|-----------:|-----------:|
+| boy | 47123 | 1280 |
+| girl | 118065 | 2588 |
+ """.strip()
+ )
+
+ # show totals
+ pivoted = pivot_df(
+ df,
+ rows=["gender"],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=False,
+ show_rows_total=True,
+ show_columns_total=True,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | ('SUM(num)',) | ('MAX(num)',) | ('Total (Sum)',) |
+|:-----------------|----------------:|----------------:|-------------------:|
+| ('boy',) | 47123 | 1280 | 48403 |
+| ('girl',) | 118065 | 2588 | 120653 |
+| ('Total (Sum)',) | 165188 | 3868 | 169056 |
+ """.strip()
+ )
+
+ # apply_metrics_on_rows
+ pivoted = pivot_df(
+ df,
+ rows=["gender"],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=False,
+ show_rows_total=True,
+ show_columns_total=True,
+ apply_metrics_on_rows=True,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | Total (Sum) |
+|:-------------------------|--------------:|
+| ('SUM(num)', 'boy') | 47123 |
+| ('SUM(num)', 'girl') | 118065 |
+| ('SUM(num)', 'Subtotal') | 165188 |
+| ('MAX(num)', 'boy') | 1280 |
+| ('MAX(num)', 'girl') | 2588 |
+| ('MAX(num)', 'Subtotal') | 3868 |
+| ('Total (Sum)', '') | 169056 |
+ """.strip()
+ )
+
+ # apply_metrics_on_rows with combine_metrics
+ pivoted = pivot_df(
+ df,
+ rows=["gender"],
+ columns=[],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=True,
+ show_rows_total=True,
+ show_columns_total=True,
+ apply_metrics_on_rows=True,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | Total (Sum) |
+|:---------------------|--------------:|
+| ('boy', 'SUM(num)') | 47123 |
+| ('boy', 'MAX(num)') | 1280 |
+| ('boy', 'Subtotal') | 48403 |
+| ('girl', 'SUM(num)') | 118065 |
+| ('girl', 'MAX(num)') | 2588 |
+| ('girl', 'Subtotal') | 120653 |
+| ('Total (Sum)', '') | 169056 |
+ """.strip()
+ )
+
+
+def test_pivot_df_complex():
+ """
+ Pivot table when a column, rows and 2 metrics are selected.
+ """
+ df = pd.DataFrame.from_dict(
+ {
+ "state": {
+ 0: "CA",
+ 1: "CA",
+ 2: "CA",
+ 3: "FL",
+ 4: "CA",
+ 5: "CA",
+ 6: "FL",
+ 7: "FL",
+ 8: "FL",
+ 9: "CA",
+ 10: "FL",
+ 11: "FL",
+ },
+ "gender": {
+ 0: "girl",
+ 1: "boy",
+ 2: "girl",
+ 3: "girl",
+ 4: "girl",
+ 5: "girl",
+ 6: "boy",
+ 7: "girl",
+ 8: "girl",
+ 9: "boy",
+ 10: "boy",
+ 11: "girl",
+ },
+ "name": {
+ 0: "Amy",
+ 1: "Edward",
+ 2: "Sophia",
+ 3: "Amy",
+ 4: "Cindy",
+ 5: "Dawn",
+ 6: "Edward",
+ 7: "Sophia",
+ 8: "Dawn",
+ 9: "Tony",
+ 10: "Tony",
+ 11: "Cindy",
+ },
+ "SUM(num)": {
+ 0: 45426,
+ 1: 31290,
+ 2: 18859,
+ 3: 14740,
+ 4: 14149,
+ 5: 11403,
+ 6: 9395,
+ 7: 7181,
+ 8: 5089,
+ 9: 3765,
+ 10: 2673,
+ 11: 1218,
+ },
+ "MAX(num)": {
+ 0: 2227,
+ 1: 1280,
+ 2: 2588,
+ 3: 854,
+ 4: 842,
+ 5: 1157,
+ 6: 389,
+ 7: 1187,
+ 8: 461,
+ 9: 598,
+ 10: 247,
+ 11: 217,
+ },
+ }
+ )
+ assert (
+ df.to_markdown()
+ == """
+| | state | gender | name | SUM(num) | MAX(num) |
+|---:|:--------|:---------|:-------|-----------:|-----------:|
+| 0 | CA | girl | Amy | 45426 | 2227 |
+| 1 | CA | boy | Edward | 31290 | 1280 |
+| 2 | CA | girl | Sophia | 18859 | 2588 |
+| 3 | FL | girl | Amy | 14740 | 854 |
+| 4 | CA | girl | Cindy | 14149 | 842 |
+| 5 | CA | girl | Dawn | 11403 | 1157 |
+| 6 | FL | boy | Edward | 9395 | 389 |
+| 7 | FL | girl | Sophia | 7181 | 1187 |
+| 8 | FL | girl | Dawn | 5089 | 461 |
+| 9 | CA | boy | Tony | 3765 | 598 |
+| 10 | FL | boy | Tony | 2673 | 247 |
+| 11 | FL | girl | Cindy | 1218 | 217 |
+ """.strip()
+ )
+
+ pivoted = pivot_df(
+ df,
+ rows=["gender", "name"],
+ columns=["state"],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=False,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | ('SUM(num)', 'CA') | ('SUM(num)', 'FL') | ('MAX(num)', 'CA') | ('MAX(num)', 'FL') |
+|:-------------------|---------------------:|---------------------:|---------------------:|---------------------:|
+| ('boy', 'Edward') | 31290 | 9395 | 1280 | 389 |
+| ('boy', 'Tony') | 3765 | 2673 | 598 | 247 |
+| ('girl', 'Amy') | 45426 | 14740 | 2227 | 854 |
+| ('girl', 'Cindy') | 14149 | 1218 | 842 | 217 |
+| ('girl', 'Dawn') | 11403 | 5089 | 1157 | 461 |
+| ('girl', 'Sophia') | 18859 | 7181 | 2588 | 1187 |
+ """.strip()
+ )
+
+ # transpose_pivot
+ pivoted = pivot_df(
+ df,
+ rows=["gender", "name"],
+ columns=["state"],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=True,
+ combine_metrics=False,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| state | ('SUM(num)', 'boy', 'Edward') | ('SUM(num)', 'boy', 'Tony') | ('SUM(num)', 'girl', 'Amy') | ('SUM(num)', 'girl', 'Cindy') | ('SUM(num)', 'girl', 'Dawn') | ('SUM(num)', 'girl', 'Sophia') | ('MAX(num)', 'boy', 'Edward') | ('MAX(num)', 'boy', 'Tony') | ('MAX(num)', 'girl', 'Amy') | ('MAX(num)', 'girl', 'Cindy') | ('MAX(num)', 'girl', 'Dawn') | ('MAX(num)', 'girl', 'Sophia') |
+|:--------|--------------------------------:|------------------------------:|------------------------------:|--------------------------------:|-------------------------------:|---------------------------------:|--------------------------------:|------------------------------:|------------------------------:|--------------------------------:|-------------------------------:|---------------------------------:|
+| CA | 31290 | 3765 | 45426 | 14149 | 11403 | 18859 | 1280 | 598 | 2227 | 842 | 1157 | 2588 |
+| FL | 9395 | 2673 | 14740 | 1218 | 5089 | 7181 | 389 | 247 | 854 | 217 | 461 | 1187 |
+ """.strip()
+ )
+
+ # combine_metrics
+ pivoted = pivot_df(
+ df,
+ rows=["gender", "name"],
+ columns=["state"],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=True,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | ('CA', 'SUM(num)') | ('CA', 'MAX(num)') | ('FL', 'SUM(num)') | ('FL', 'MAX(num)') |
+|:-------------------|---------------------:|---------------------:|---------------------:|---------------------:|
+| ('boy', 'Edward') | 31290 | 1280 | 9395 | 389 |
+| ('boy', 'Tony') | 3765 | 598 | 2673 | 247 |
+| ('girl', 'Amy') | 45426 | 2227 | 14740 | 854 |
+| ('girl', 'Cindy') | 14149 | 842 | 1218 | 217 |
+| ('girl', 'Dawn') | 11403 | 1157 | 5089 | 461 |
+| ('girl', 'Sophia') | 18859 | 2588 | 7181 | 1187 |
+ """.strip()
+ )
+
+ # show totals
+ pivoted = pivot_df(
+ df,
+ rows=["gender", "name"],
+ columns=["state"],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=False,
+ show_rows_total=True,
+ show_columns_total=True,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | ('SUM(num)', 'CA') | ('SUM(num)', 'FL') | ('SUM(num)', 'Subtotal') | ('MAX(num)', 'CA') | ('MAX(num)', 'FL') | ('MAX(num)', 'Subtotal') | ('Total (Sum)', '') |
+|:---------------------|---------------------:|---------------------:|---------------------------:|---------------------:|---------------------:|---------------------------:|----------------------:|
+| ('boy', 'Edward') | 31290 | 9395 | 40685 | 1280 | 389 | 1669 | 42354 |
+| ('boy', 'Tony') | 3765 | 2673 | 6438 | 598 | 247 | 845 | 7283 |
+| ('boy', 'Subtotal') | 35055 | 12068 | 47123 | 1878 | 636 | 2514 | 49637 |
+| ('girl', 'Amy') | 45426 | 14740 | 60166 | 2227 | 854 | 3081 | 63247 |
+| ('girl', 'Cindy') | 14149 | 1218 | 15367 | 842 | 217 | 1059 | 16426 |
+| ('girl', 'Dawn') | 11403 | 5089 | 16492 | 1157 | 461 | 1618 | 18110 |
+| ('girl', 'Sophia') | 18859 | 7181 | 26040 | 2588 | 1187 | 3775 | 29815 |
+| ('girl', 'Subtotal') | 89837 | 28228 | 118065 | 6814 | 2719 | 9533 | 127598 |
+| ('Total (Sum)', '') | 124892 | 40296 | 165188 | 8692 | 3355 | 12047 | 177235 |
+ """.strip()
+ )
+
+ # apply_metrics_on_rows
+ pivoted = pivot_df(
+ df,
+ rows=["gender", "name"],
+ columns=["state"],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=False,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=True,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | CA | FL |
+|:-------------------------------|------:|------:|
+| ('SUM(num)', 'boy', 'Edward') | 31290 | 9395 |
+| ('SUM(num)', 'boy', 'Tony') | 3765 | 2673 |
+| ('SUM(num)', 'girl', 'Amy') | 45426 | 14740 |
+| ('SUM(num)', 'girl', 'Cindy') | 14149 | 1218 |
+| ('SUM(num)', 'girl', 'Dawn') | 11403 | 5089 |
+| ('SUM(num)', 'girl', 'Sophia') | 18859 | 7181 |
+| ('MAX(num)', 'boy', 'Edward') | 1280 | 389 |
+| ('MAX(num)', 'boy', 'Tony') | 598 | 247 |
+| ('MAX(num)', 'girl', 'Amy') | 2227 | 854 |
+| ('MAX(num)', 'girl', 'Cindy') | 842 | 217 |
+| ('MAX(num)', 'girl', 'Dawn') | 1157 | 461 |
+| ('MAX(num)', 'girl', 'Sophia') | 2588 | 1187 |
+ """.strip()
+ )
+
+ # apply_metrics_on_rows with combine_metrics
+ pivoted = pivot_df(
+ df,
+ rows=["gender", "name"],
+ columns=["state"],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=False,
+ combine_metrics=True,
+ show_rows_total=False,
+ show_columns_total=False,
+ apply_metrics_on_rows=True,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | CA | FL |
+|:-------------------------------|------:|------:|
+| ('boy', 'Edward', 'SUM(num)') | 31290 | 9395 |
+| ('boy', 'Edward', 'MAX(num)') | 1280 | 389 |
+| ('boy', 'Tony', 'SUM(num)') | 3765 | 2673 |
+| ('boy', 'Tony', 'MAX(num)') | 598 | 247 |
+| ('girl', 'Amy', 'SUM(num)') | 45426 | 14740 |
+| ('girl', 'Amy', 'MAX(num)') | 2227 | 854 |
+| ('girl', 'Cindy', 'SUM(num)') | 14149 | 1218 |
+| ('girl', 'Cindy', 'MAX(num)') | 842 | 217 |
+| ('girl', 'Dawn', 'SUM(num)') | 11403 | 5089 |
+| ('girl', 'Dawn', 'MAX(num)') | 1157 | 461 |
+| ('girl', 'Sophia', 'SUM(num)') | 18859 | 7181 |
+| ('girl', 'Sophia', 'MAX(num)') | 2588 | 1187 |
+ """.strip()
+ )
+
+ # everything
+ pivoted = pivot_df(
+ df,
+ rows=["gender", "name"],
+ columns=["state"],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum",
+ transpose_pivot=True,
+ combine_metrics=True,
+ show_rows_total=True,
+ show_columns_total=True,
+ apply_metrics_on_rows=True,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | ('boy', 'Edward') | ('boy', 'Tony') | ('boy', 'Subtotal') | ('girl', 'Amy') | ('girl', 'Cindy') | ('girl', 'Dawn') | ('girl', 'Sophia') | ('girl', 'Subtotal') | ('Total (Sum)', '') |
+|:--------------------|--------------------:|------------------:|----------------------:|------------------:|--------------------:|-------------------:|---------------------:|-----------------------:|----------------------:|
+| ('CA', 'SUM(num)') | 31290 | 3765 | 35055 | 45426 | 14149 | 11403 | 18859 | 89837 | 124892 |
+| ('CA', 'MAX(num)') | 1280 | 598 | 1878 | 2227 | 842 | 1157 | 2588 | 6814 | 8692 |
+| ('CA', 'Subtotal') | 32570 | 4363 | 36933 | 47653 | 14991 | 12560 | 21447 | 96651 | 133584 |
+| ('FL', 'SUM(num)') | 9395 | 2673 | 12068 | 14740 | 1218 | 5089 | 7181 | 28228 | 40296 |
+| ('FL', 'MAX(num)') | 389 | 247 | 636 | 854 | 217 | 461 | 1187 | 2719 | 3355 |
+| ('FL', 'Subtotal') | 9784 | 2920 | 12704 | 15594 | 1435 | 5550 | 8368 | 30947 | 43651 |
+| ('Total (Sum)', '') | 42354 | 7283 | 49637 | 63247 | 16426 | 18110 | 29815 | 127598 | 177235 |
+ """.strip()
+ )
+
+ # fraction
+ pivoted = pivot_df(
+ df,
+ rows=["gender", "name"],
+ columns=["state"],
+ metrics=["SUM(num)", "MAX(num)"],
+ aggfunc="Sum as Fraction of Columns",
+ transpose_pivot=False,
+ combine_metrics=False,
+ show_rows_total=False,
+ show_columns_total=True,
+ apply_metrics_on_rows=False,
+ )
+ assert (
+ pivoted.to_markdown()
+ == """
+| | ('SUM(num)', 'CA') | ('SUM(num)', 'FL') | ('MAX(num)', 'CA') | ('MAX(num)', 'FL') |
+|:-------------------------------------------|---------------------:|---------------------:|---------------------:|---------------------:|
+| ('boy', 'Edward') | 0.250536 | 0.23315 | 0.147262 | 0.115946 |
+| ('boy', 'Tony') | 0.030146 | 0.0663341 | 0.0687989 | 0.0736215 |
+| ('boy', 'Subtotal') | 0.280683 | 0.299484 | 0.216061 | 0.189568 |
+| ('girl', 'Amy') | 0.363722 | 0.365793 | 0.256213 | 0.254545 |
+| ('girl', 'Cindy') | 0.11329 | 0.0302263 | 0.0968707 | 0.0646796 |
+| ('girl', 'Dawn') | 0.0913029 | 0.12629 | 0.133111 | 0.137407 |
+| ('girl', 'Sophia') | 0.151002 | 0.178206 | 0.297745 | 0.3538 |
+| ('girl', 'Subtotal') | 0.719317 | 0.700516 | 0.783939 | 0.810432 |
+| ('Total (Sum as Fraction of Columns)', '') | 1 | 1 | 1 | 1 |
+ """.strip()
+ )