You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by yo...@apache.org on 2022/02/17 12:07:26 UTC
[superset] branch master updated: refactor: postprocessing move to unit test (#18779)
This is an automated email from the ASF dual-hosted git repository.
yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 30a9d14 refactor: postprocessing move to unit test (#18779)
30a9d14 is described below
commit 30a9d14639fe3072a7886ac201777e62077e532a
Author: Yongjie Zhao <yo...@gmail.com>
AuthorDate: Thu Feb 17 20:05:41 2022 +0800
refactor: postprocessing move to unit test (#18779)
---
.../pandas_postprocessing_tests.py | 1098 --------------------
.../fixtures/dataframes.py | 0
tests/unit_tests/pandas_postprocessing/__init__.py | 16 +
.../pandas_postprocessing/test_aggregate.py | 40 +
.../pandas_postprocessing/test_boxplot.py | 126 +++
.../pandas_postprocessing/test_compare.py | 62 ++
.../pandas_postprocessing/test_contribution.py | 69 ++
tests/unit_tests/pandas_postprocessing/test_cum.py | 97 ++
.../unit_tests/pandas_postprocessing/test_diff.py | 50 +
.../pandas_postprocessing/test_geography.py | 90 ++
.../unit_tests/pandas_postprocessing/test_pivot.py | 266 +++++
.../pandas_postprocessing/test_prophet.py | 114 ++
.../pandas_postprocessing/test_resample.py | 107 ++
.../pandas_postprocessing/test_rolling.py | 147 +++
.../pandas_postprocessing/test_select.py | 55 +
.../unit_tests/pandas_postprocessing/test_sort.py | 30 +
tests/unit_tests/pandas_postprocessing/utils.py | 55 +
17 files changed, 1324 insertions(+), 1098 deletions(-)
diff --git a/tests/integration_tests/pandas_postprocessing_tests.py b/tests/integration_tests/pandas_postprocessing_tests.py
deleted file mode 100644
index 50612e1..0000000
--- a/tests/integration_tests/pandas_postprocessing_tests.py
+++ /dev/null
@@ -1,1098 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# isort:skip_file
-from datetime import datetime
-from importlib.util import find_spec
-import math
-from typing import Any, List, Optional
-
-import numpy as np
-from pandas import DataFrame, Series, Timestamp, to_datetime
-import pytest
-
-from superset.exceptions import QueryObjectValidationError
-from superset.utils import pandas_postprocessing as proc
-from superset.utils.core import (
- DTTM_ALIAS,
- PostProcessingContributionOrientation,
- PostProcessingBoxplotWhiskerType,
-)
-
-from .base_tests import SupersetTestCase
-from .fixtures.dataframes import (
- categories_df,
- single_metric_df,
- multiple_metrics_df,
- lonlat_df,
- names_df,
- timeseries_df,
- prophet_df,
- timeseries_df2,
-)
-
-AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
-AGGREGATES_MULTIPLE = {
- "idx_nulls": {"operator": "sum"},
- "asc_idx": {"operator": "mean"},
-}
-
-
-def series_to_list(series: Series) -> List[Any]:
- """
- Converts a `Series` to a regular list, and replaces non-numeric values to
- Nones.
-
- :param series: Series to convert
- :return: list without nan or inf
- """
- return [
- None
- if not isinstance(val, str) and (math.isnan(val) or math.isinf(val))
- else val
- for val in series.tolist()
- ]
-
-
-def round_floats(
- floats: List[Optional[float]], precision: int
-) -> List[Optional[float]]:
- """
- Round list of floats to certain precision
-
- :param floats: floats to round
- :param precision: intended decimal precision
- :return: rounded floats
- """
- return [round(val, precision) if val else None for val in floats]
-
-
-class TestPostProcessing(SupersetTestCase):
- def test_flatten_column_after_pivot(self):
- """
- Test pivot column flattening function
- """
- # single aggregate cases
- self.assertEqual(
- proc._flatten_column_after_pivot(
- aggregates=AGGREGATES_SINGLE, column="idx_nulls",
- ),
- "idx_nulls",
- )
- self.assertEqual(
- proc._flatten_column_after_pivot(
- aggregates=AGGREGATES_SINGLE, column=1234,
- ),
- "1234",
- )
- self.assertEqual(
- proc._flatten_column_after_pivot(
- aggregates=AGGREGATES_SINGLE, column=Timestamp("2020-09-29T00:00:00"),
- ),
- "2020-09-29 00:00:00",
- )
- self.assertEqual(
- proc._flatten_column_after_pivot(
- aggregates=AGGREGATES_SINGLE, column="idx_nulls",
- ),
- "idx_nulls",
- )
- self.assertEqual(
- proc._flatten_column_after_pivot(
- aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"),
- ),
- "col1",
- )
- self.assertEqual(
- proc._flatten_column_after_pivot(
- aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", 1234),
- ),
- "col1, 1234",
- )
-
- # Multiple aggregate cases
- self.assertEqual(
- proc._flatten_column_after_pivot(
- aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"),
- ),
- "idx_nulls, asc_idx, col1",
- )
- self.assertEqual(
- proc._flatten_column_after_pivot(
- aggregates=AGGREGATES_MULTIPLE,
- column=("idx_nulls", "asc_idx", "col1", 1234),
- ),
- "idx_nulls, asc_idx, col1, 1234",
- )
-
- def test_pivot_without_columns(self):
- """
- Make sure pivot without columns returns correct DataFrame
- """
- df = proc.pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,)
- self.assertListEqual(
- df.columns.tolist(), ["name", "idx_nulls"],
- )
- self.assertEqual(len(df), 101)
- self.assertEqual(df.sum()[1], 1050)
-
- def test_pivot_with_single_column(self):
- """
- Make sure pivot with single column returns correct DataFrame
- """
- df = proc.pivot(
- df=categories_df,
- index=["name"],
- columns=["category"],
- aggregates=AGGREGATES_SINGLE,
- )
- self.assertListEqual(
- df.columns.tolist(), ["name", "cat0", "cat1", "cat2"],
- )
- self.assertEqual(len(df), 101)
- self.assertEqual(df.sum()[1], 315)
-
- df = proc.pivot(
- df=categories_df,
- index=["dept"],
- columns=["category"],
- aggregates=AGGREGATES_SINGLE,
- )
- self.assertListEqual(
- df.columns.tolist(), ["dept", "cat0", "cat1", "cat2"],
- )
- self.assertEqual(len(df), 5)
-
- def test_pivot_with_multiple_columns(self):
- """
- Make sure pivot with multiple columns returns correct DataFrame
- """
- df = proc.pivot(
- df=categories_df,
- index=["name"],
- columns=["category", "dept"],
- aggregates=AGGREGATES_SINGLE,
- )
- self.assertEqual(len(df.columns), 1 + 3 * 5) # index + possible permutations
-
- def test_pivot_fill_values(self):
- """
- Make sure pivot with fill values returns correct DataFrame
- """
- df = proc.pivot(
- df=categories_df,
- index=["name"],
- columns=["category"],
- metric_fill_value=1,
- aggregates={"idx_nulls": {"operator": "sum"}},
- )
- self.assertEqual(df.sum()[1], 382)
-
- def test_pivot_fill_column_values(self):
- """
- Make sure pivot witn null column names returns correct DataFrame
- """
- df_copy = categories_df.copy()
- df_copy["category"] = None
- df = proc.pivot(
- df=df_copy,
- index=["name"],
- columns=["category"],
- aggregates={"idx_nulls": {"operator": "sum"}},
- )
- assert len(df) == 101
- assert df.columns.tolist() == ["name", "<NULL>"]
-
- def test_pivot_exceptions(self):
- """
- Make sure pivot raises correct Exceptions
- """
- # Missing index
- self.assertRaises(
- TypeError,
- proc.pivot,
- df=categories_df,
- columns=["dept"],
- aggregates=AGGREGATES_SINGLE,
- )
-
- # invalid index reference
- self.assertRaises(
- QueryObjectValidationError,
- proc.pivot,
- df=categories_df,
- index=["abc"],
- columns=["dept"],
- aggregates=AGGREGATES_SINGLE,
- )
-
- # invalid column reference
- self.assertRaises(
- QueryObjectValidationError,
- proc.pivot,
- df=categories_df,
- index=["dept"],
- columns=["abc"],
- aggregates=AGGREGATES_SINGLE,
- )
-
- # invalid aggregate options
- self.assertRaises(
- QueryObjectValidationError,
- proc.pivot,
- df=categories_df,
- index=["name"],
- columns=["category"],
- aggregates={"idx_nulls": {}},
- )
-
- def test_pivot_eliminate_cartesian_product_columns(self):
- # single metric
- mock_df = DataFrame(
- {
- "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
- "a": [0, 1],
- "b": [0, 1],
- "metric": [9, np.NAN],
- }
- )
-
- df = proc.pivot(
- df=mock_df,
- index=["dttm"],
- columns=["a", "b"],
- aggregates={"metric": {"operator": "mean"}},
- drop_missing_columns=False,
- )
- self.assertEqual(list(df.columns), ["dttm", "0, 0", "1, 1"])
- self.assertTrue(np.isnan(df["1, 1"][0]))
-
- # multiple metrics
- mock_df = DataFrame(
- {
- "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
- "a": [0, 1],
- "b": [0, 1],
- "metric": [9, np.NAN],
- "metric2": [10, 11],
- }
- )
-
- df = proc.pivot(
- df=mock_df,
- index=["dttm"],
- columns=["a", "b"],
- aggregates={
- "metric": {"operator": "mean"},
- "metric2": {"operator": "mean"},
- },
- drop_missing_columns=False,
- )
- self.assertEqual(
- list(df.columns),
- ["dttm", "metric, 0, 0", "metric, 1, 1", "metric2, 0, 0", "metric2, 1, 1"],
- )
- self.assertTrue(np.isnan(df["metric, 1, 1"][0]))
-
- def test_pivot_without_flatten_columns_and_reset_index(self):
- df = proc.pivot(
- df=single_metric_df,
- index=["dttm"],
- columns=["country"],
- aggregates={"sum_metric": {"operator": "sum"}},
- flatten_columns=False,
- reset_index=False,
- )
- # metric
- # country UK US
- # dttm
- # 2019-01-01 5 6
- # 2019-01-02 7 8
- assert df.columns.to_list() == [("sum_metric", "UK"), ("sum_metric", "US")]
- assert df.index.to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
-
- def test_aggregate(self):
- aggregates = {
- "asc sum": {"column": "asc_idx", "operator": "sum"},
- "asc q2": {
- "column": "asc_idx",
- "operator": "percentile",
- "options": {"q": 75},
- },
- "desc q1": {
- "column": "desc_idx",
- "operator": "percentile",
- "options": {"q": 25},
- },
- }
- df = proc.aggregate(
- df=categories_df, groupby=["constant"], aggregates=aggregates
- )
- self.assertListEqual(
- df.columns.tolist(), ["constant", "asc sum", "asc q2", "desc q1"]
- )
- self.assertEqual(series_to_list(df["asc sum"])[0], 5050)
- self.assertEqual(series_to_list(df["asc q2"])[0], 75)
- self.assertEqual(series_to_list(df["desc q1"])[0], 25)
-
- def test_sort(self):
- df = proc.sort(df=categories_df, columns={"category": True, "asc_idx": False})
- self.assertEqual(96, series_to_list(df["asc_idx"])[1])
-
- self.assertRaises(
- QueryObjectValidationError, proc.sort, df=df, columns={"abc": True}
- )
-
- def test_rolling(self):
- # sum rolling type
- post_df = proc.rolling(
- df=timeseries_df,
- columns={"y": "y"},
- rolling_type="sum",
- window=2,
- min_periods=0,
- )
-
- self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
- self.assertListEqual(series_to_list(post_df["y"]), [1.0, 3.0, 5.0, 7.0])
-
- # mean rolling type with alias
- post_df = proc.rolling(
- df=timeseries_df,
- rolling_type="mean",
- columns={"y": "y_mean"},
- window=10,
- min_periods=0,
- )
- self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y_mean"])
- self.assertListEqual(series_to_list(post_df["y_mean"]), [1.0, 1.5, 2.0, 2.5])
-
- # count rolling type
- post_df = proc.rolling(
- df=timeseries_df,
- rolling_type="count",
- columns={"y": "y"},
- window=10,
- min_periods=0,
- )
- self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
- self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
-
- # quantile rolling type
- post_df = proc.rolling(
- df=timeseries_df,
- columns={"y": "q1"},
- rolling_type="quantile",
- rolling_type_options={"quantile": 0.25},
- window=10,
- min_periods=0,
- )
- self.assertListEqual(post_df.columns.tolist(), ["label", "y", "q1"])
- self.assertListEqual(series_to_list(post_df["q1"]), [1.0, 1.25, 1.5, 1.75])
-
- # incorrect rolling type
- self.assertRaises(
- QueryObjectValidationError,
- proc.rolling,
- df=timeseries_df,
- columns={"y": "y"},
- rolling_type="abc",
- window=2,
- )
-
- # incorrect rolling type options
- self.assertRaises(
- QueryObjectValidationError,
- proc.rolling,
- df=timeseries_df,
- columns={"y": "y"},
- rolling_type="quantile",
- rolling_type_options={"abc": 123},
- window=2,
- )
-
- def test_rolling_with_pivot_df_and_single_metric(self):
- pivot_df = proc.pivot(
- df=single_metric_df,
- index=["dttm"],
- columns=["country"],
- aggregates={"sum_metric": {"operator": "sum"}},
- flatten_columns=False,
- reset_index=False,
- )
- rolling_df = proc.rolling(
- df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
- )
- # dttm UK US
- # 0 2019-01-01 5 6
- # 1 2019-01-02 12 14
- assert rolling_df["UK"].to_list() == [5.0, 12.0]
- assert rolling_df["US"].to_list() == [6.0, 14.0]
- assert (
- rolling_df["dttm"].to_list()
- == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
- )
-
- rolling_df = proc.rolling(
- df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
- )
- assert rolling_df.empty is True
-
- def test_rolling_with_pivot_df_and_multiple_metrics(self):
- pivot_df = proc.pivot(
- df=multiple_metrics_df,
- index=["dttm"],
- columns=["country"],
- aggregates={
- "sum_metric": {"operator": "sum"},
- "count_metric": {"operator": "sum"},
- },
- flatten_columns=False,
- reset_index=False,
- )
- rolling_df = proc.rolling(
- df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
- )
- # dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
- # 0 2019-01-01 1.0 2.0 5.0 6.0
- # 1 2019-01-02 4.0 6.0 12.0 14.0
- assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0]
- assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0]
- assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0]
- assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0]
- assert (
- rolling_df["dttm"].to_list()
- == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
- )
-
- def test_select(self):
- # reorder columns
- post_df = proc.select(df=timeseries_df, columns=["y", "label"])
- self.assertListEqual(post_df.columns.tolist(), ["y", "label"])
-
- # one column
- post_df = proc.select(df=timeseries_df, columns=["label"])
- self.assertListEqual(post_df.columns.tolist(), ["label"])
-
- # rename and select one column
- post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
- self.assertListEqual(post_df.columns.tolist(), ["y1"])
-
- # rename one and leave one unchanged
- post_df = proc.select(df=timeseries_df, rename={"y": "y1"})
- self.assertListEqual(post_df.columns.tolist(), ["label", "y1"])
-
- # drop one column
- post_df = proc.select(df=timeseries_df, exclude=["label"])
- self.assertListEqual(post_df.columns.tolist(), ["y"])
-
- # rename and drop one column
- post_df = proc.select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"])
- self.assertListEqual(post_df.columns.tolist(), ["y1"])
-
- # invalid columns
- self.assertRaises(
- QueryObjectValidationError,
- proc.select,
- df=timeseries_df,
- columns=["abc"],
- rename={"abc": "qwerty"},
- )
-
- # select renamed column by new name
- self.assertRaises(
- QueryObjectValidationError,
- proc.select,
- df=timeseries_df,
- columns=["label_new"],
- rename={"label": "label_new"},
- )
-
- def test_diff(self):
- # overwrite column
- post_df = proc.diff(df=timeseries_df, columns={"y": "y"})
- self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
- self.assertListEqual(series_to_list(post_df["y"]), [None, 1.0, 1.0, 1.0])
-
- # add column
- post_df = proc.diff(df=timeseries_df, columns={"y": "y1"})
- self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y1"])
- self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
- self.assertListEqual(series_to_list(post_df["y1"]), [None, 1.0, 1.0, 1.0])
-
- # look ahead
- post_df = proc.diff(df=timeseries_df, columns={"y": "y1"}, periods=-1)
- self.assertListEqual(series_to_list(post_df["y1"]), [-1.0, -1.0, -1.0, None])
-
- # invalid column reference
- self.assertRaises(
- QueryObjectValidationError,
- proc.diff,
- df=timeseries_df,
- columns={"abc": "abc"},
- )
-
- # diff by columns
- post_df = proc.diff(df=timeseries_df2, columns={"y": "y", "z": "z"}, axis=1)
- self.assertListEqual(post_df.columns.tolist(), ["label", "y", "z"])
- self.assertListEqual(series_to_list(post_df["z"]), [0.0, 2.0, 8.0, 6.0])
-
- def test_compare(self):
- # `difference` comparison
- post_df = proc.compare(
- df=timeseries_df2,
- source_columns=["y"],
- compare_columns=["z"],
- compare_type="difference",
- )
- self.assertListEqual(
- post_df.columns.tolist(), ["label", "y", "z", "difference__y__z",]
- )
- self.assertListEqual(
- series_to_list(post_df["difference__y__z"]), [0.0, -2.0, -8.0, -6.0],
- )
-
- # drop original columns
- post_df = proc.compare(
- df=timeseries_df2,
- source_columns=["y"],
- compare_columns=["z"],
- compare_type="difference",
- drop_original_columns=True,
- )
- self.assertListEqual(post_df.columns.tolist(), ["label", "difference__y__z",])
-
- # `percentage` comparison
- post_df = proc.compare(
- df=timeseries_df2,
- source_columns=["y"],
- compare_columns=["z"],
- compare_type="percentage",
- )
- self.assertListEqual(
- post_df.columns.tolist(), ["label", "y", "z", "percentage__y__z",]
- )
- self.assertListEqual(
- series_to_list(post_df["percentage__y__z"]), [0.0, -0.5, -0.8, -0.75],
- )
-
- # `ratio` comparison
- post_df = proc.compare(
- df=timeseries_df2,
- source_columns=["y"],
- compare_columns=["z"],
- compare_type="ratio",
- )
- self.assertListEqual(
- post_df.columns.tolist(), ["label", "y", "z", "ratio__y__z",]
- )
- self.assertListEqual(
- series_to_list(post_df["ratio__y__z"]), [1.0, 0.5, 0.2, 0.25],
- )
-
- def test_cum(self):
- # create new column (cumsum)
- post_df = proc.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
- self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y2"])
- self.assertListEqual(series_to_list(post_df["label"]), ["x", "y", "z", "q"])
- self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
- self.assertListEqual(series_to_list(post_df["y2"]), [1.0, 3.0, 6.0, 10.0])
-
- # overwrite column (cumprod)
- post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
- self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
- self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 6.0, 24.0])
-
- # overwrite column (cummin)
- post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
- self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
- self.assertListEqual(series_to_list(post_df["y"]), [1.0, 1.0, 1.0, 1.0])
-
- # invalid operator
- self.assertRaises(
- QueryObjectValidationError,
- proc.cum,
- df=timeseries_df,
- columns={"y": "y"},
- operator="abc",
- )
-
- def test_cum_with_pivot_df_and_single_metric(self):
- pivot_df = proc.pivot(
- df=single_metric_df,
- index=["dttm"],
- columns=["country"],
- aggregates={"sum_metric": {"operator": "sum"}},
- flatten_columns=False,
- reset_index=False,
- )
- cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
- # dttm UK US
- # 0 2019-01-01 5 6
- # 1 2019-01-02 12 14
- assert cum_df["UK"].to_list() == [5.0, 12.0]
- assert cum_df["US"].to_list() == [6.0, 14.0]
- assert (
- cum_df["dttm"].to_list()
- == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
- )
-
- def test_cum_with_pivot_df_and_multiple_metrics(self):
- pivot_df = proc.pivot(
- df=multiple_metrics_df,
- index=["dttm"],
- columns=["country"],
- aggregates={
- "sum_metric": {"operator": "sum"},
- "count_metric": {"operator": "sum"},
- },
- flatten_columns=False,
- reset_index=False,
- )
- cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
- # dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
- # 0 2019-01-01 1 2 5 6
- # 1 2019-01-02 4 6 12 14
- assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
- assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
- assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
- assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
- assert (
- cum_df["dttm"].to_list()
- == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
- )
-
- def test_geohash_decode(self):
- # decode lon/lat from geohash
- post_df = proc.geohash_decode(
- df=lonlat_df[["city", "geohash"]],
- geohash="geohash",
- latitude="latitude",
- longitude="longitude",
- )
- self.assertListEqual(
- sorted(post_df.columns.tolist()),
- sorted(["city", "geohash", "latitude", "longitude"]),
- )
- self.assertListEqual(
- round_floats(series_to_list(post_df["longitude"]), 6),
- round_floats(series_to_list(lonlat_df["longitude"]), 6),
- )
- self.assertListEqual(
- round_floats(series_to_list(post_df["latitude"]), 6),
- round_floats(series_to_list(lonlat_df["latitude"]), 6),
- )
-
- def test_geohash_encode(self):
- # encode lon/lat into geohash
- post_df = proc.geohash_encode(
- df=lonlat_df[["city", "latitude", "longitude"]],
- latitude="latitude",
- longitude="longitude",
- geohash="geohash",
- )
- self.assertListEqual(
- sorted(post_df.columns.tolist()),
- sorted(["city", "geohash", "latitude", "longitude"]),
- )
- self.assertListEqual(
- series_to_list(post_df["geohash"]), series_to_list(lonlat_df["geohash"]),
- )
-
- def test_geodetic_parse(self):
- # parse geodetic string with altitude into lon/lat/altitude
- post_df = proc.geodetic_parse(
- df=lonlat_df[["city", "geodetic"]],
- geodetic="geodetic",
- latitude="latitude",
- longitude="longitude",
- altitude="altitude",
- )
- self.assertListEqual(
- sorted(post_df.columns.tolist()),
- sorted(["city", "geodetic", "latitude", "longitude", "altitude"]),
- )
- self.assertListEqual(
- series_to_list(post_df["longitude"]),
- series_to_list(lonlat_df["longitude"]),
- )
- self.assertListEqual(
- series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
- )
- self.assertListEqual(
- series_to_list(post_df["altitude"]), series_to_list(lonlat_df["altitude"]),
- )
-
- # parse geodetic string into lon/lat
- post_df = proc.geodetic_parse(
- df=lonlat_df[["city", "geodetic"]],
- geodetic="geodetic",
- latitude="latitude",
- longitude="longitude",
- )
- self.assertListEqual(
- sorted(post_df.columns.tolist()),
- sorted(["city", "geodetic", "latitude", "longitude"]),
- )
- self.assertListEqual(
- series_to_list(post_df["longitude"]),
- series_to_list(lonlat_df["longitude"]),
- )
- self.assertListEqual(
- series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
- )
-
- def test_contribution(self):
- df = DataFrame(
- {
- DTTM_ALIAS: [
- datetime(2020, 7, 16, 14, 49),
- datetime(2020, 7, 16, 14, 50),
- ],
- "a": [1, 3],
- "b": [1, 9],
- }
- )
- with pytest.raises(QueryObjectValidationError, match="not numeric"):
- proc.contribution(df, columns=[DTTM_ALIAS])
-
- with pytest.raises(QueryObjectValidationError, match="same length"):
- proc.contribution(df, columns=["a"], rename_columns=["aa", "bb"])
-
- # cell contribution across row
- processed_df = proc.contribution(
- df, orientation=PostProcessingContributionOrientation.ROW,
- )
- self.assertListEqual(processed_df.columns.tolist(), [DTTM_ALIAS, "a", "b"])
- self.assertListEqual(processed_df["a"].tolist(), [0.5, 0.25])
- self.assertListEqual(processed_df["b"].tolist(), [0.5, 0.75])
-
- # cell contribution across column without temporal column
- df.pop(DTTM_ALIAS)
- processed_df = proc.contribution(
- df, orientation=PostProcessingContributionOrientation.COLUMN
- )
- self.assertListEqual(processed_df.columns.tolist(), ["a", "b"])
- self.assertListEqual(processed_df["a"].tolist(), [0.25, 0.75])
- self.assertListEqual(processed_df["b"].tolist(), [0.1, 0.9])
-
- # contribution only on selected columns
- processed_df = proc.contribution(
- df,
- orientation=PostProcessingContributionOrientation.COLUMN,
- columns=["a"],
- rename_columns=["pct_a"],
- )
- self.assertListEqual(processed_df.columns.tolist(), ["a", "b", "pct_a"])
- self.assertListEqual(processed_df["a"].tolist(), [1, 3])
- self.assertListEqual(processed_df["b"].tolist(), [1, 9])
- self.assertListEqual(processed_df["pct_a"].tolist(), [0.25, 0.75])
-
- def test_prophet_valid(self):
- pytest.importorskip("prophet")
-
- df = proc.prophet(
- df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9
- )
- columns = {column for column in df.columns}
- assert columns == {
- DTTM_ALIAS,
- "a__yhat",
- "a__yhat_upper",
- "a__yhat_lower",
- "a",
- "b__yhat",
- "b__yhat_upper",
- "b__yhat_lower",
- "b",
- }
- assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
- assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31)
- assert len(df) == 7
-
- df = proc.prophet(
- df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9
- )
- assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
- assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
- assert len(df) == 9
-
- def test_prophet_valid_zero_periods(self):
- pytest.importorskip("prophet")
-
- df = proc.prophet(
- df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9
- )
- columns = {column for column in df.columns}
- assert columns == {
- DTTM_ALIAS,
- "a__yhat",
- "a__yhat_upper",
- "a__yhat_lower",
- "a",
- "b__yhat",
- "b__yhat_upper",
- "b__yhat_lower",
- "b",
- }
- assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
- assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31)
- assert len(df) == 4
-
- def test_prophet_import(self):
- prophet = find_spec("prophet")
- if prophet is None:
- with pytest.raises(QueryObjectValidationError):
- proc.prophet(
- df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9
- )
-
- def test_prophet_missing_temporal_column(self):
- df = prophet_df.drop(DTTM_ALIAS, axis=1)
-
- self.assertRaises(
- QueryObjectValidationError,
- proc.prophet,
- df=df,
- time_grain="P1M",
- periods=3,
- confidence_interval=0.9,
- )
-
- def test_prophet_incorrect_confidence_interval(self):
- self.assertRaises(
- QueryObjectValidationError,
- proc.prophet,
- df=prophet_df,
- time_grain="P1M",
- periods=3,
- confidence_interval=0.0,
- )
-
- self.assertRaises(
- QueryObjectValidationError,
- proc.prophet,
- df=prophet_df,
- time_grain="P1M",
- periods=3,
- confidence_interval=1.0,
- )
-
- def test_prophet_incorrect_periods(self):
- self.assertRaises(
- QueryObjectValidationError,
- proc.prophet,
- df=prophet_df,
- time_grain="P1M",
- periods=-1,
- confidence_interval=0.8,
- )
-
- def test_prophet_incorrect_time_grain(self):
- self.assertRaises(
- QueryObjectValidationError,
- proc.prophet,
- df=prophet_df,
- time_grain="yearly",
- periods=10,
- confidence_interval=0.8,
- )
-
- def test_boxplot_tukey(self):
- df = proc.boxplot(
- df=names_df,
- groupby=["region"],
- whisker_type=PostProcessingBoxplotWhiskerType.TUKEY,
- metrics=["cars"],
- )
- columns = {column for column in df.columns}
- assert columns == {
- "cars__mean",
- "cars__median",
- "cars__q1",
- "cars__q3",
- "cars__max",
- "cars__min",
- "cars__count",
- "cars__outliers",
- "region",
- }
- assert len(df) == 4
-
- def test_boxplot_min_max(self):
- df = proc.boxplot(
- df=names_df,
- groupby=["region"],
- whisker_type=PostProcessingBoxplotWhiskerType.MINMAX,
- metrics=["cars"],
- )
- columns = {column for column in df.columns}
- assert columns == {
- "cars__mean",
- "cars__median",
- "cars__q1",
- "cars__q3",
- "cars__max",
- "cars__min",
- "cars__count",
- "cars__outliers",
- "region",
- }
- assert len(df) == 4
-
- def test_boxplot_percentile(self):
- df = proc.boxplot(
- df=names_df,
- groupby=["region"],
- whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
- metrics=["cars"],
- percentiles=[1, 99],
- )
- columns = {column for column in df.columns}
- assert columns == {
- "cars__mean",
- "cars__median",
- "cars__q1",
- "cars__q3",
- "cars__max",
- "cars__min",
- "cars__count",
- "cars__outliers",
- "region",
- }
- assert len(df) == 4
-
- def test_boxplot_percentile_incorrect_params(self):
- with pytest.raises(QueryObjectValidationError):
- proc.boxplot(
- df=names_df,
- groupby=["region"],
- whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
- metrics=["cars"],
- )
-
- with pytest.raises(QueryObjectValidationError):
- proc.boxplot(
- df=names_df,
- groupby=["region"],
- whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
- metrics=["cars"],
- percentiles=[10],
- )
-
- with pytest.raises(QueryObjectValidationError):
- proc.boxplot(
- df=names_df,
- groupby=["region"],
- whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
- metrics=["cars"],
- percentiles=[90, 10],
- )
-
- with pytest.raises(QueryObjectValidationError):
- proc.boxplot(
- df=names_df,
- groupby=["region"],
- whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
- metrics=["cars"],
- percentiles=[10, 90, 10],
- )
-
- def test_resample(self):
- df = timeseries_df.copy()
- df.index.name = "time_column"
- df.reset_index(inplace=True)
-
- post_df = proc.resample(
- df=df, rule="1D", method="ffill", time_column="time_column",
- )
- self.assertListEqual(
- post_df["label"].tolist(), ["x", "y", "y", "y", "z", "z", "q"]
- )
- self.assertListEqual(post_df["y"].tolist(), [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0])
-
- post_df = proc.resample(
- df=df, rule="1D", method="asfreq", time_column="time_column", fill_value=0,
- )
- self.assertListEqual(post_df["label"].tolist(), ["x", "y", 0, 0, "z", 0, "q"])
- self.assertListEqual(post_df["y"].tolist(), [1.0, 2.0, 0, 0, 3.0, 0, 4.0])
-
- def test_resample_with_groupby(self):
- """
-The Dataframe contains a timestamp column, a string column and a numeric column.
- __timestamp city val
-0 2022-01-13 Chicago 6.0
-1 2022-01-13 LA 5.0
-2 2022-01-13 NY 4.0
-3 2022-01-11 Chicago 3.0
-4 2022-01-11 LA 2.0
-5 2022-01-11 NY 1.0
- """
- df = DataFrame(
- {
- "__timestamp": to_datetime(
- [
- "2022-01-13",
- "2022-01-13",
- "2022-01-13",
- "2022-01-11",
- "2022-01-11",
- "2022-01-11",
- ]
- ),
- "city": ["Chicago", "LA", "NY", "Chicago", "LA", "NY"],
- "val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
- }
- )
- post_df = proc.resample(
- df=df,
- rule="1D",
- method="asfreq",
- fill_value=0,
- time_column="__timestamp",
- groupby_columns=("city",),
- )
- assert list(post_df.columns) == [
- "__timestamp",
- "city",
- "val",
- ]
- assert [str(dt.date()) for dt in post_df["__timestamp"]] == (
- ["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
- )
- assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]
-
- # should raise error when get a non-existent column
- with pytest.raises(QueryObjectValidationError):
- proc.resample(
- df=df,
- rule="1D",
- method="asfreq",
- fill_value=0,
- time_column="__timestamp",
- groupby_columns=("city", "unkonw_column",),
- )
-
- # should raise error when get a None value in groupby list
- with pytest.raises(QueryObjectValidationError):
- proc.resample(
- df=df,
- rule="1D",
- method="asfreq",
- fill_value=0,
- time_column="__timestamp",
- groupby_columns=("city", None,),
- )
diff --git a/tests/integration_tests/fixtures/dataframes.py b/tests/unit_tests/fixtures/dataframes.py
similarity index 100%
rename from tests/integration_tests/fixtures/dataframes.py
rename to tests/unit_tests/fixtures/dataframes.py
diff --git a/tests/unit_tests/pandas_postprocessing/__init__.py b/tests/unit_tests/pandas_postprocessing/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/pandas_postprocessing/test_aggregate.py b/tests/unit_tests/pandas_postprocessing/test_aggregate.py
new file mode 100644
index 0000000..69d42e3
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_aggregate.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from superset.utils.pandas_postprocessing import aggregate
+from tests.unit_tests.fixtures.dataframes import categories_df
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
+def test_aggregate():
+ aggregates = {
+ "asc sum": {"column": "asc_idx", "operator": "sum"},
+ "asc q2": {
+ "column": "asc_idx",
+ "operator": "percentile",
+ "options": {"q": 75},
+ },
+ "desc q1": {
+ "column": "desc_idx",
+ "operator": "percentile",
+ "options": {"q": 25},
+ },
+ }
+ df = aggregate(df=categories_df, groupby=["constant"], aggregates=aggregates)
+ assert df.columns.tolist() == ["constant", "asc sum", "asc q2", "desc q1"]
+ assert series_to_list(df["asc sum"])[0] == 5050
+ assert series_to_list(df["asc q2"])[0] == 75
+ assert series_to_list(df["desc q1"])[0] == 25
diff --git a/tests/unit_tests/pandas_postprocessing/test_boxplot.py b/tests/unit_tests/pandas_postprocessing/test_boxplot.py
new file mode 100644
index 0000000..247aba0
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_boxplot.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.core import PostProcessingBoxplotWhiskerType
+from superset.utils.pandas_postprocessing import boxplot
+from tests.unit_tests.fixtures.dataframes import names_df
+
+
+def test_boxplot_tukey():
+ df = boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.TUKEY,
+ metrics=["cars"],
+ )
+ columns = {column for column in df.columns}
+ assert columns == {
+ "cars__mean",
+ "cars__median",
+ "cars__q1",
+ "cars__q3",
+ "cars__max",
+ "cars__min",
+ "cars__count",
+ "cars__outliers",
+ "region",
+ }
+ assert len(df) == 4
+
+
+def test_boxplot_min_max():
+ df = boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.MINMAX,
+ metrics=["cars"],
+ )
+ columns = {column for column in df.columns}
+ assert columns == {
+ "cars__mean",
+ "cars__median",
+ "cars__q1",
+ "cars__q3",
+ "cars__max",
+ "cars__min",
+ "cars__count",
+ "cars__outliers",
+ "region",
+ }
+ assert len(df) == 4
+
+
+def test_boxplot_percentile():
+ df = boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ percentiles=[1, 99],
+ )
+ columns = {column for column in df.columns}
+ assert columns == {
+ "cars__mean",
+ "cars__median",
+ "cars__q1",
+ "cars__q3",
+ "cars__max",
+ "cars__min",
+ "cars__count",
+ "cars__outliers",
+ "region",
+ }
+ assert len(df) == 4
+
+
+def test_boxplot_percentile_incorrect_params():
+ with pytest.raises(QueryObjectValidationError):
+ boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ )
+
+ with pytest.raises(QueryObjectValidationError):
+ boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ percentiles=[10],
+ )
+
+ with pytest.raises(QueryObjectValidationError):
+ boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ percentiles=[90, 10],
+ )
+
+ with pytest.raises(QueryObjectValidationError):
+ boxplot(
+ df=names_df,
+ groupby=["region"],
+ whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
+ metrics=["cars"],
+ percentiles=[10, 90, 10],
+ )
diff --git a/tests/unit_tests/pandas_postprocessing/test_compare.py b/tests/unit_tests/pandas_postprocessing/test_compare.py
new file mode 100644
index 0000000..d9213ca
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_compare.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from superset.utils.pandas_postprocessing import compare
+from tests.unit_tests.fixtures.dataframes import timeseries_df2
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
+def test_compare():
+ # `difference` comparison
+ post_df = compare(
+ df=timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type="difference",
+ )
+ assert post_df.columns.tolist() == ["label", "y", "z", "difference__y__z"]
+ assert series_to_list(post_df["difference__y__z"]) == [0.0, -2.0, -8.0, -6.0]
+
+ # drop original columns
+ post_df = compare(
+ df=timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type="difference",
+ drop_original_columns=True,
+ )
+ assert post_df.columns.tolist() == ["label", "difference__y__z"]
+
+ # `percentage` comparison
+ post_df = compare(
+ df=timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type="percentage",
+ )
+ assert post_df.columns.tolist() == ["label", "y", "z", "percentage__y__z"]
+ assert series_to_list(post_df["percentage__y__z"]) == [0.0, -0.5, -0.8, -0.75]
+
+ # `ratio` comparison
+ post_df = compare(
+ df=timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type="ratio",
+ )
+ assert post_df.columns.tolist() == ["label", "y", "z", "ratio__y__z"]
+ assert series_to_list(post_df["ratio__y__z"]) == [1.0, 0.5, 0.2, 0.25]
diff --git a/tests/unit_tests/pandas_postprocessing/test_contribution.py b/tests/unit_tests/pandas_postprocessing/test_contribution.py
new file mode 100644
index 0000000..78212cb
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_contribution.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+import pytest
+from pandas import DataFrame
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
+from superset.utils.pandas_postprocessing import contribution
+
+
+def test_contribution():
+ df = DataFrame(
+ {
+ DTTM_ALIAS: [datetime(2020, 7, 16, 14, 49), datetime(2020, 7, 16, 14, 50),],
+ "a": [1, 3],
+ "b": [1, 9],
+ }
+ )
+ with pytest.raises(QueryObjectValidationError, match="not numeric"):
+ contribution(df, columns=[DTTM_ALIAS])
+
+ with pytest.raises(QueryObjectValidationError, match="same length"):
+ contribution(df, columns=["a"], rename_columns=["aa", "bb"])
+
+ # cell contribution across row
+ processed_df = contribution(
+ df, orientation=PostProcessingContributionOrientation.ROW,
+ )
+ assert processed_df.columns.tolist() == [DTTM_ALIAS, "a", "b"]
+ assert processed_df["a"].tolist() == [0.5, 0.25]
+ assert processed_df["b"].tolist() == [0.5, 0.75]
+
+ # cell contribution across column without temporal column
+ df.pop(DTTM_ALIAS)
+ processed_df = contribution(
+ df, orientation=PostProcessingContributionOrientation.COLUMN
+ )
+ assert processed_df.columns.tolist() == ["a", "b"]
+ assert processed_df["a"].tolist() == [0.25, 0.75]
+ assert processed_df["b"].tolist() == [0.1, 0.9]
+
+ # contribution only on selected columns
+ processed_df = contribution(
+ df,
+ orientation=PostProcessingContributionOrientation.COLUMN,
+ columns=["a"],
+ rename_columns=["pct_a"],
+ )
+ assert processed_df.columns.tolist() == ["a", "b", "pct_a"]
+ assert processed_df["a"].tolist() == [1, 3]
+ assert processed_df["b"].tolist() == [1, 9]
+ assert processed_df["pct_a"].tolist() == [0.25, 0.75]
diff --git a/tests/unit_tests/pandas_postprocessing/test_cum.py b/tests/unit_tests/pandas_postprocessing/test_cum.py
new file mode 100644
index 0000000..b4b8fad
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_cum.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+from pandas import to_datetime
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.pandas_postprocessing import cum, pivot
+from tests.unit_tests.fixtures.dataframes import (
+ multiple_metrics_df,
+ single_metric_df,
+ timeseries_df,
+)
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
+def test_cum():
+ # create new column (cumsum)
+ post_df = cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
+ assert post_df.columns.tolist() == ["label", "y", "y2"]
+ assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"]
+ assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
+ assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0]
+
+ # overwrite column (cumprod)
+ post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
+ assert post_df.columns.tolist() == ["label", "y"]
+ assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0]
+
+ # overwrite column (cummin)
+ post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
+ assert post_df.columns.tolist() == ["label", "y"]
+ assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0]
+
+ # invalid operator
+ with pytest.raises(QueryObjectValidationError):
+ cum(
+ df=timeseries_df, columns={"y": "y"}, operator="abc",
+ )
+
+
+def test_cum_with_pivot_df_and_single_metric():
+ pivot_df = pivot(
+ df=single_metric_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={"sum_metric": {"operator": "sum"}},
+ flatten_columns=False,
+ reset_index=False,
+ )
+ cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
+ # dttm UK US
+ # 0 2019-01-01 5 6
+ # 1 2019-01-02 12 14
+ assert cum_df["UK"].to_list() == [5.0, 12.0]
+ assert cum_df["US"].to_list() == [6.0, 14.0]
+ assert (
+ cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
+ )
+
+
+def test_cum_with_pivot_df_and_multiple_metrics():
+ pivot_df = pivot(
+ df=multiple_metrics_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={
+ "sum_metric": {"operator": "sum"},
+ "count_metric": {"operator": "sum"},
+ },
+ flatten_columns=False,
+ reset_index=False,
+ )
+ cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
+ # dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
+ # 0 2019-01-01 1 2 5 6
+ # 1 2019-01-02 4 6 12 14
+ assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
+ assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
+ assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
+ assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
+ assert (
+ cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
+ )
diff --git a/tests/unit_tests/pandas_postprocessing/test_diff.py b/tests/unit_tests/pandas_postprocessing/test_diff.py
new file mode 100644
index 0000000..abade20
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_diff.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.pandas_postprocessing import diff
+from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
+def test_diff():
+ # overwrite column
+ post_df = diff(df=timeseries_df, columns={"y": "y"})
+ assert post_df.columns.tolist() == ["label", "y"]
+ assert series_to_list(post_df["y"]) == [None, 1.0, 1.0, 1.0]
+
+ # add column
+ post_df = diff(df=timeseries_df, columns={"y": "y1"})
+ assert post_df.columns.tolist() == ["label", "y", "y1"]
+ assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
+ assert series_to_list(post_df["y1"]) == [None, 1.0, 1.0, 1.0]
+
+ # look ahead
+ post_df = diff(df=timeseries_df, columns={"y": "y1"}, periods=-1)
+ assert series_to_list(post_df["y1"]) == [-1.0, -1.0, -1.0, None]
+
+ # invalid column reference
+ with pytest.raises(QueryObjectValidationError):
+ diff(
+ df=timeseries_df, columns={"abc": "abc"},
+ )
+
+ # diff by columns
+ post_df = diff(df=timeseries_df2, columns={"y": "y", "z": "z"}, axis=1)
+ assert post_df.columns.tolist() == ["label", "y", "z"]
+ assert series_to_list(post_df["z"]) == [0.0, 2.0, 8.0, 6.0]
diff --git a/tests/unit_tests/pandas_postprocessing/test_geography.py b/tests/unit_tests/pandas_postprocessing/test_geography.py
new file mode 100644
index 0000000..6162f3c
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_geography.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from superset.utils.pandas_postprocessing import (
+ geodetic_parse,
+ geohash_decode,
+ geohash_encode,
+)
+from tests.unit_tests.fixtures.dataframes import lonlat_df
+from tests.unit_tests.pandas_postprocessing.utils import round_floats, series_to_list
+
+
+def test_geohash_decode():
+ # decode lon/lat from geohash
+ post_df = geohash_decode(
+ df=lonlat_df[["city", "geohash"]],
+ geohash="geohash",
+ latitude="latitude",
+ longitude="longitude",
+ )
+ assert sorted(post_df.columns.tolist()) == sorted(
+ ["city", "geohash", "latitude", "longitude"]
+ )
+ assert round_floats(series_to_list(post_df["longitude"]), 6) == round_floats(
+ series_to_list(lonlat_df["longitude"]), 6
+ )
+ assert round_floats(series_to_list(post_df["latitude"]), 6) == round_floats(
+ series_to_list(lonlat_df["latitude"]), 6
+ )
+
+
+def test_geohash_encode():
+ # encode lon/lat into geohash
+ post_df = geohash_encode(
+ df=lonlat_df[["city", "latitude", "longitude"]],
+ latitude="latitude",
+ longitude="longitude",
+ geohash="geohash",
+ )
+ assert sorted(post_df.columns.tolist()) == sorted(
+ ["city", "geohash", "latitude", "longitude"]
+ )
+ assert series_to_list(post_df["geohash"]) == series_to_list(lonlat_df["geohash"])
+
+
+def test_geodetic_parse():
+ # parse geodetic string with altitude into lon/lat/altitude
+ post_df = geodetic_parse(
+ df=lonlat_df[["city", "geodetic"]],
+ geodetic="geodetic",
+ latitude="latitude",
+ longitude="longitude",
+ altitude="altitude",
+ )
+ assert sorted(post_df.columns.tolist()) == sorted(
+ ["city", "geodetic", "latitude", "longitude", "altitude"]
+ )
+ assert series_to_list(post_df["longitude"]) == series_to_list(
+ lonlat_df["longitude"]
+ )
+ assert series_to_list(post_df["latitude"]) == series_to_list(lonlat_df["latitude"])
+ assert series_to_list(post_df["altitude"]) == series_to_list(lonlat_df["altitude"])
+
+ # parse geodetic string into lon/lat
+ post_df = geodetic_parse(
+ df=lonlat_df[["city", "geodetic"]],
+ geodetic="geodetic",
+ latitude="latitude",
+ longitude="longitude",
+ )
+ assert sorted(post_df.columns.tolist()) == sorted(
+ ["city", "geodetic", "latitude", "longitude"]
+ )
+ assert series_to_list(post_df["longitude"]) == series_to_list(
+ lonlat_df["longitude"]
+ )
+ assert series_to_list(post_df["latitude"]) == series_to_list(lonlat_df["latitude"])
diff --git a/tests/unit_tests/pandas_postprocessing/test_pivot.py b/tests/unit_tests/pandas_postprocessing/test_pivot.py
new file mode 100644
index 0000000..55779e3
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_pivot.py
@@ -0,0 +1,266 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+from pandas import DataFrame, Timestamp, to_datetime
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.pandas_postprocessing import _flatten_column_after_pivot, pivot
+from tests.unit_tests.fixtures.dataframes import categories_df, single_metric_df
+from tests.unit_tests.pandas_postprocessing.utils import (
+ AGGREGATES_MULTIPLE,
+ AGGREGATES_SINGLE,
+)
+
+
+def test_flatten_column_after_pivot():
+ """
+ Test pivot column flattening function
+ """
+ # single aggregate cases
+ assert (
+ _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column="idx_nulls",)
+ == "idx_nulls"
+ )
+
+ assert (
+ _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column=1234,)
+ == "1234"
+ )
+
+ assert (
+ _flatten_column_after_pivot(
+ aggregates=AGGREGATES_SINGLE, column=Timestamp("2020-09-29T00:00:00"),
+ )
+ == "2020-09-29 00:00:00"
+ )
+
+ assert (
+ _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column="idx_nulls",)
+ == "idx_nulls"
+ )
+
+ assert (
+ _flatten_column_after_pivot(
+ aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"),
+ )
+ == "col1"
+ )
+
+ assert (
+ _flatten_column_after_pivot(
+ aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", 1234),
+ )
+ == "col1, 1234"
+ )
+
+ # Multiple aggregate cases
+ assert (
+ _flatten_column_after_pivot(
+ aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"),
+ )
+ == "idx_nulls, asc_idx, col1"
+ )
+
+ assert (
+ _flatten_column_after_pivot(
+ aggregates=AGGREGATES_MULTIPLE,
+ column=("idx_nulls", "asc_idx", "col1", 1234),
+ )
+ == "idx_nulls, asc_idx, col1, 1234"
+ )
+
+
+def test_pivot_without_columns():
+ """
+ Make sure pivot without columns returns correct DataFrame
+ """
+ df = pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,)
+ assert df.columns.tolist() == ["name", "idx_nulls"]
+ assert len(df) == 101
+ assert df.sum()[1] == 1050
+
+
+def test_pivot_with_single_column():
+ """
+ Make sure pivot with single column returns correct DataFrame
+ """
+ df = pivot(
+ df=categories_df,
+ index=["name"],
+ columns=["category"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+ assert df.columns.tolist() == ["name", "cat0", "cat1", "cat2"]
+ assert len(df) == 101
+ assert df.sum()[1] == 315
+
+ df = pivot(
+ df=categories_df,
+ index=["dept"],
+ columns=["category"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+ assert df.columns.tolist() == ["dept", "cat0", "cat1", "cat2"]
+ assert len(df) == 5
+
+
+def test_pivot_with_multiple_columns():
+ """
+ Make sure pivot with multiple columns returns correct DataFrame
+ """
+ df = pivot(
+ df=categories_df,
+ index=["name"],
+ columns=["category", "dept"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+ assert len(df.columns) == 1 + 3 * 5 # index + possible permutations
+
+
+def test_pivot_fill_values():
+ """
+ Make sure pivot with fill values returns correct DataFrame
+ """
+ df = pivot(
+ df=categories_df,
+ index=["name"],
+ columns=["category"],
+ metric_fill_value=1,
+ aggregates={"idx_nulls": {"operator": "sum"}},
+ )
+ assert df.sum()[1] == 382
+
+
+def test_pivot_fill_column_values():
+ """
+ Make sure pivot with null column names returns correct DataFrame
+ """
+ df_copy = categories_df.copy()
+ df_copy["category"] = None
+ df = pivot(
+ df=df_copy,
+ index=["name"],
+ columns=["category"],
+ aggregates={"idx_nulls": {"operator": "sum"}},
+ )
+ assert len(df) == 101
+ assert df.columns.tolist() == ["name", "<NULL>"]
+
+
+def test_pivot_exceptions():
+ """
+ Make sure pivot raises correct Exceptions
+ """
+ # Missing index
+ with pytest.raises(TypeError):
+ pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE)
+
+ # invalid index reference
+ with pytest.raises(QueryObjectValidationError):
+ pivot(
+ df=categories_df,
+ index=["abc"],
+ columns=["dept"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+
+ # invalid column reference
+ with pytest.raises(QueryObjectValidationError):
+ pivot(
+ df=categories_df,
+ index=["dept"],
+ columns=["abc"],
+ aggregates=AGGREGATES_SINGLE,
+ )
+
+ # invalid aggregate options
+ with pytest.raises(QueryObjectValidationError):
+ pivot(
+ df=categories_df,
+ index=["name"],
+ columns=["category"],
+ aggregates={"idx_nulls": {}},
+ )
+
+
+def test_pivot_eliminate_cartesian_product_columns():
+ # single metric
+ mock_df = DataFrame(
+ {
+ "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
+ "a": [0, 1],
+ "b": [0, 1],
+ "metric": [9, np.NAN],
+ }
+ )
+
+ df = pivot(
+ df=mock_df,
+ index=["dttm"],
+ columns=["a", "b"],
+ aggregates={"metric": {"operator": "mean"}},
+ drop_missing_columns=False,
+ )
+ assert list(df.columns) == ["dttm", "0, 0", "1, 1"]
+ assert np.isnan(df["1, 1"][0])
+
+ # multiple metrics
+ mock_df = DataFrame(
+ {
+ "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
+ "a": [0, 1],
+ "b": [0, 1],
+ "metric": [9, np.NAN],
+ "metric2": [10, 11],
+ }
+ )
+
+ df = pivot(
+ df=mock_df,
+ index=["dttm"],
+ columns=["a", "b"],
+ aggregates={"metric": {"operator": "mean"}, "metric2": {"operator": "mean"},},
+ drop_missing_columns=False,
+ )
+ assert list(df.columns) == [
+ "dttm",
+ "metric, 0, 0",
+ "metric, 1, 1",
+ "metric2, 0, 0",
+ "metric2, 1, 1",
+ ]
+ assert np.isnan(df["metric, 1, 1"][0])
+
+
+def test_pivot_without_flatten_columns_and_reset_index():
+ df = pivot(
+ df=single_metric_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={"sum_metric": {"operator": "sum"}},
+ flatten_columns=False,
+ reset_index=False,
+ )
+ # metric
+ # country UK US
+ # dttm
+ # 2019-01-01 5 6
+ # 2019-01-02 7 8
+ assert df.columns.to_list() == [("sum_metric", "UK"), ("sum_metric", "US")]
+ assert df.index.to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
diff --git a/tests/unit_tests/pandas_postprocessing/test_prophet.py b/tests/unit_tests/pandas_postprocessing/test_prophet.py
new file mode 100644
index 0000000..ce5c45b
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_prophet.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from datetime import datetime
+from importlib.util import find_spec
+
+import pytest
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.core import DTTM_ALIAS
+from superset.utils.pandas_postprocessing import prophet
+from tests.unit_tests.fixtures.dataframes import prophet_df
+
+
+def test_prophet_valid():
+ pytest.importorskip("prophet")
+
+ df = prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
+ columns = {column for column in df.columns}
+ assert columns == {
+ DTTM_ALIAS,
+ "a__yhat",
+ "a__yhat_upper",
+ "a__yhat_lower",
+ "a",
+ "b__yhat",
+ "b__yhat_upper",
+ "b__yhat_lower",
+ "b",
+ }
+ assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31)
+ assert len(df) == 7
+
+ df = prophet(df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9)
+ assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
+ assert len(df) == 9
+
+
+def test_prophet_valid_zero_periods():
+ pytest.importorskip("prophet")
+
+ df = prophet(df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9)
+ columns = {column for column in df.columns}
+ assert columns == {
+ DTTM_ALIAS,
+ "a__yhat",
+ "a__yhat_upper",
+ "a__yhat_lower",
+ "a",
+ "b__yhat",
+ "b__yhat_upper",
+ "b__yhat_lower",
+ "b",
+ }
+ assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
+ assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31)
+ assert len(df) == 4
+
+
+def test_prophet_import():
+ dynamic_module = find_spec("prophet")
+ if dynamic_module is None:
+ with pytest.raises(QueryObjectValidationError):
+ prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
+
+
+def test_prophet_missing_temporal_column():
+ df = prophet_df.drop(DTTM_ALIAS, axis=1)
+
+ with pytest.raises(QueryObjectValidationError):
+ prophet(
+ df=df, time_grain="P1M", periods=3, confidence_interval=0.9,
+ )
+
+
+def test_prophet_incorrect_confidence_interval():
+ with pytest.raises(QueryObjectValidationError):
+ prophet(
+ df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.0,
+ )
+
+ with pytest.raises(QueryObjectValidationError):
+ prophet(
+ df=prophet_df, time_grain="P1M", periods=3, confidence_interval=1.0,
+ )
+
+
+def test_prophet_incorrect_periods():
+ with pytest.raises(QueryObjectValidationError):
+ prophet(
+ df=prophet_df, time_grain="P1M", periods=-1, confidence_interval=0.8,
+ )
+
+
+def test_prophet_incorrect_time_grain():
+ with pytest.raises(QueryObjectValidationError):
+ prophet(
+ df=prophet_df, time_grain="yearly", periods=10, confidence_interval=0.8,
+ )
diff --git a/tests/unit_tests/pandas_postprocessing/test_resample.py b/tests/unit_tests/pandas_postprocessing/test_resample.py
new file mode 100644
index 0000000..872f2ed
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_resample.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+from pandas import DataFrame, to_datetime
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.pandas_postprocessing import resample
+from tests.unit_tests.fixtures.dataframes import timeseries_df
+
+
def test_resample():
    """Daily resampling via forward-fill and via as-frequency padding."""
    df = timeseries_df.copy()
    df.index.name = "time_column"
    df.reset_index(inplace=True)

    # ffill carries the last observation across the generated gap rows
    filled = resample(df=df, rule="1D", method="ffill", time_column="time_column")
    assert list(filled["label"]) == ["x", "y", "y", "y", "z", "z", "q"]
    assert list(filled["y"]) == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]

    # asfreq inserts the gap rows and fills them with the fill_value
    padded = resample(
        df=df,
        rule="1D",
        method="asfreq",
        time_column="time_column",
        fill_value=0,
    )
    assert list(padded["label"]) == ["x", "y", 0, 0, "z", 0, "q"]
    assert list(padded["y"]) == [1.0, 2.0, 0, 0, 3.0, 0, 4.0]
+
+
def test_resample_with_groupby():
    """
    Resampling with ``groupby_columns`` fills time gaps per group.

    The input frame contains a timestamp column, a string column and a
    numeric column:

        __timestamp     city  val
    0    2022-01-13  Chicago  6.0
    1    2022-01-13       LA  5.0
    2    2022-01-13       NY  4.0
    3    2022-01-11  Chicago  3.0
    4    2022-01-11       LA  2.0
    5    2022-01-11       NY  1.0
    """
    df = DataFrame(
        {
            "__timestamp": to_datetime(
                [
                    "2022-01-13",
                    "2022-01-13",
                    "2022-01-13",
                    "2022-01-11",
                    "2022-01-11",
                    "2022-01-11",
                ]
            ),
            "city": ["Chicago", "LA", "NY", "Chicago", "LA", "NY"],
            "val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        }
    )
    post_df = resample(
        df=df,
        rule="1D",
        method="asfreq",
        fill_value=0,
        time_column="__timestamp",
        groupby_columns=("city",),
    )
    assert list(post_df.columns) == [
        "__timestamp",
        "city",
        "val",
    ]
    # the missing 2022-01-12 rows are generated for every city and zero-filled
    assert [str(dt.date()) for dt in post_df["__timestamp"]] == (
        ["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
    )
    assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]

    # should raise an error for a non-existent groupby column
    # (placeholder name spelling fixed from "unkonw_column"; the column is
    # still absent from the frame either way)
    with pytest.raises(QueryObjectValidationError):
        resample(
            df=df,
            rule="1D",
            method="asfreq",
            fill_value=0,
            time_column="__timestamp",
            groupby_columns=("city", "unknown_column"),
        )

    # should raise an error when the groupby list contains a None value
    with pytest.raises(QueryObjectValidationError):
        resample(
            df=df,
            rule="1D",
            method="asfreq",
            fill_value=0,
            time_column="__timestamp",
            groupby_columns=("city", None),
        )
diff --git a/tests/unit_tests/pandas_postprocessing/test_rolling.py b/tests/unit_tests/pandas_postprocessing/test_rolling.py
new file mode 100644
index 0000000..227b03a
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_rolling.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+from pandas import to_datetime
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.pandas_postprocessing import pivot, rolling
+from tests.unit_tests.fixtures.dataframes import (
+ multiple_metrics_df,
+ single_metric_df,
+ timeseries_df,
+)
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
def test_rolling():
    """Exercise rolling() for each supported aggregation and its failure modes."""
    # rolling sum over a two-row window
    summed = rolling(
        df=timeseries_df,
        columns={"y": "y"},
        rolling_type="sum",
        window=2,
        min_periods=0,
    )
    assert summed.columns.tolist() == ["label", "y"]
    assert series_to_list(summed["y"]) == [1.0, 3.0, 5.0, 7.0]

    # rolling mean written to an aliased output column
    averaged = rolling(
        df=timeseries_df,
        rolling_type="mean",
        columns={"y": "y_mean"},
        window=10,
        min_periods=0,
    )
    assert averaged.columns.tolist() == ["label", "y", "y_mean"]
    assert series_to_list(averaged["y_mean"]) == [1.0, 1.5, 2.0, 2.5]

    # rolling count
    counted = rolling(
        df=timeseries_df,
        rolling_type="count",
        columns={"y": "y"},
        window=10,
        min_periods=0,
    )
    assert counted.columns.tolist() == ["label", "y"]
    assert series_to_list(counted["y"]) == [1.0, 2.0, 3.0, 4.0]

    # rolling quantile configured through rolling_type_options
    quantiled = rolling(
        df=timeseries_df,
        columns={"y": "q1"},
        rolling_type="quantile",
        rolling_type_options={"quantile": 0.25},
        window=10,
        min_periods=0,
    )
    assert quantiled.columns.tolist() == ["label", "y", "q1"]
    assert series_to_list(quantiled["q1"]) == [1.0, 1.25, 1.5, 1.75]

    # an unknown rolling type is rejected
    with pytest.raises(QueryObjectValidationError):
        rolling(df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2)

    # options the rolling type does not accept are rejected
    with pytest.raises(QueryObjectValidationError):
        rolling(
            df=timeseries_df,
            columns={"y": "y"},
            rolling_type="quantile",
            rolling_type_options={"abc": 123},
            window=2,
        )
+
+
def test_rolling_with_pivot_df_and_single_metric():
    """rolling() over a pivoted frame restores the index as a plain column."""
    pivot_df = pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    # Expected rolled result:
    #         dttm   UK   US
    # 0 2019-01-01    5    6
    # 1 2019-01-02   12   14
    rolled = rolling(
        df=pivot_df,
        rolling_type="sum",
        window=2,
        min_periods=0,
        is_pivot_df=True,
    )
    assert rolled["UK"].to_list() == [5.0, 12.0]
    assert rolled["US"].to_list() == [6.0, 14.0]
    expected_dates = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert rolled["dttm"].to_list() == expected_dates

    # with min_periods=2 the result frame comes back empty
    rolled = rolling(
        df=pivot_df,
        rolling_type="sum",
        window=2,
        min_periods=2,
        is_pivot_df=True,
    )
    assert rolled.empty is True
+
+
def test_rolling_with_pivot_df_and_multiple_metrics():
    """rolling() over a two-metric pivot handles the combined metric/country columns."""
    pivot_df = pivot(
        df=multiple_metrics_df,
        index=["dttm"],
        columns=["country"],
        aggregates={
            "sum_metric": {"operator": "sum"},
            "count_metric": {"operator": "sum"},
        },
        flatten_columns=False,
        reset_index=False,
    )
    rolled = rolling(
        df=pivot_df,
        rolling_type="sum",
        window=2,
        min_periods=0,
        is_pivot_df=True,
    )
    # Expected rolled result:
    #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
    # 0 2019-01-01               1.0               2.0             5.0             6.0
    # 1 2019-01-02               4.0               6.0            12.0            14.0
    assert rolled["count_metric, UK"].to_list() == [1.0, 4.0]
    assert rolled["count_metric, US"].to_list() == [2.0, 6.0]
    assert rolled["sum_metric, UK"].to_list() == [5.0, 12.0]
    assert rolled["sum_metric, US"].to_list() == [6.0, 14.0]
    expected_dates = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert rolled["dttm"].to_list() == expected_dates
diff --git a/tests/unit_tests/pandas_postprocessing/test_select.py b/tests/unit_tests/pandas_postprocessing/test_select.py
new file mode 100644
index 0000000..aac644d
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_select.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.pandas_postprocessing.select import select
+from tests.unit_tests.fixtures.dataframes import timeseries_df
+
+
def test_select():
    """select() can reorder, subset, rename and drop columns; bad names raise."""
    # reorder the two columns
    reordered = select(df=timeseries_df, columns=["y", "label"])
    assert list(reordered.columns) == ["y", "label"]

    # keep a single column
    only_label = select(df=timeseries_df, columns=["label"])
    assert list(only_label.columns) == ["label"]

    # select one column under a new name
    renamed = select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
    assert list(renamed.columns) == ["y1"]

    # rename one column and keep the other untouched
    partly_renamed = select(df=timeseries_df, rename={"y": "y1"})
    assert list(partly_renamed.columns) == ["label", "y1"]

    # drop one column
    without_label = select(df=timeseries_df, exclude=["label"])
    assert list(without_label.columns) == ["y"]

    # rename and drop in the same call
    renamed_and_dropped = select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"])
    assert list(renamed_and_dropped.columns) == ["y1"]

    # an unknown source column is rejected
    with pytest.raises(QueryObjectValidationError):
        select(df=timeseries_df, columns=["abc"], rename={"abc": "qwerty"})

    # the selection list uses pre-rename names, so the new name is rejected
    with pytest.raises(QueryObjectValidationError):
        select(df=timeseries_df, columns=["label_new"], rename={"label": "label_new"})
diff --git a/tests/unit_tests/pandas_postprocessing/test_sort.py b/tests/unit_tests/pandas_postprocessing/test_sort.py
new file mode 100644
index 0000000..43daa9c
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_sort.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from superset.exceptions import QueryObjectValidationError
+from superset.utils.pandas_postprocessing import sort
+from tests.unit_tests.fixtures.dataframes import categories_df
+from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+
+
def test_sort():
    """sort() orders by the given columns/directions; unknown columns raise."""
    sorted_df = sort(df=categories_df, columns={"category": True, "asc_idx": False})
    assert series_to_list(sorted_df["asc_idx"])[1] == 96

    with pytest.raises(QueryObjectValidationError):
        sort(df=sorted_df, columns={"abc": True})
diff --git a/tests/unit_tests/pandas_postprocessing/utils.py b/tests/unit_tests/pandas_postprocessing/utils.py
new file mode 100644
index 0000000..07366b1
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/utils.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import math
+from typing import Any, List, Optional
+
+from pandas import Series
+
+AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
+AGGREGATES_MULTIPLE = {
+ "idx_nulls": {"operator": "sum"},
+ "asc_idx": {"operator": "mean"},
+}
+
+
def series_to_list(series: Series) -> List[Any]:
    """
    Convert a `Series` to a plain list, mapping NaN and infinite entries
    to ``None``.

    :param series: Series to convert
    :return: list with NaN/inf replaced by None
    """
    result: List[Any] = []
    for item in series.tolist():
        if isinstance(item, str):
            # strings pass through untouched (isnan/isinf reject them)
            result.append(item)
        elif math.isnan(item) or math.isinf(item):
            result.append(None)
        else:
            result.append(item)
    return result
+
+
def round_floats(
    floats: List[Optional[float]], precision: int
) -> List[Optional[float]]:
    """
    Round a list of floats to the given decimal precision, passing
    ``None`` entries through unchanged.

    :param floats: floats to round (entries may be None)
    :param precision: intended decimal precision
    :return: rounded floats
    """
    # Test `val is not None` rather than truthiness: 0.0 is falsy and the
    # previous `if val` check silently replaced it with None.
    return [round(val, precision) if val is not None else None for val in floats]