You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/09/13 06:44:54 UTC

[spark] branch master updated: [SPARK-40399][PS] Make `pearson` correlation in `DataFrame.corr` support missing values and `min_periods `

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ae08787f5c5 [SPARK-40399][PS] Make `pearson` correlation in `DataFrame.corr` support missing values and `min_periods `
ae08787f5c5 is described below

commit ae08787f5c50e485ef4432a0c2da8b3b7290d725
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Sep 13 14:44:18 2022 +0800

    [SPARK-40399][PS] Make `pearson` correlation in `DataFrame.corr` support missing values and `min_periods `
    
    ### What changes were proposed in this pull request?
    Refactor `pearson` correlation in `DataFrame.corr` to:
    
    1. support missing values;
    2. add the `min_periods` parameter;
    3. enable Arrow execution, since it no longer depends on `VectorUDT`;
    4. support lazy evaluation.
    
    Before:
    ```
    In [1]: import pyspark.pandas as ps
    
    In [2]: df = ps.DataFrame([[1,2], [3,None]])
    
    In [3]: df
    
       0    1
    0  1  2.0
    1  3  NaN
    
    In [4]: df.corr()
    22/09/09 16:53:18 ERROR Executor: Exception in task 9.0 in stage 5.0 (TID 24)
    org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] Failed to execute user defined function (VectorAssembler$$Lambda$2660/0x0000000801215840: (struct<0_double_VectorAssembler_0915f96ec689:double,1:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
    ```
    
    After:
    ```
    In [1]: import pyspark.pandas as ps
    
    In [2]: df = ps.DataFrame([[1,2], [3,None]])
    
    In [3]: df.corr()
    
         0   1
    0  1.0 NaN
    1  NaN NaN
    
    In [4]: df.to_pandas().corr()
    /Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:976: PandasAPIOnSparkAdviceWarning: `to_pandas` loads all data into the driver's memory. It should only be used if the resulting pandas DataFrame is expected to be small.
      warnings.warn(message, PandasAPIOnSparkAdviceWarning)
    Out[4]:
         0   1
    0  1.0 NaN
    1  NaN NaN
    ```
    
    ### Why are the changes needed?
    To improve API coverage and to support the common case of data that contains missing values.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. This is an API change: `DataFrame.corr` now accepts a new parameter, `min_periods`.
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #37845 from zhengruifeng/ps_df_corr_missing_value.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/pandas/frame.py            | 209 +++++++++++++++++++++++++++++-
 python/pyspark/pandas/tests/test_stats.py |  34 +++++
 2 files changed, 238 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 3438d07896e..cf14a548266 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -1417,15 +1417,23 @@ class DataFrame(Frame, Generic[T]):
 
     agg = aggregate
 
-    def corr(self, method: str = "pearson") -> "DataFrame":
+    def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "DataFrame":
         """
         Compute pairwise correlation of columns, excluding NA/null values.
 
+        .. versionadded:: 3.3.0
+
         Parameters
         ----------
         method : {'pearson', 'spearman'}
             * pearson : standard correlation coefficient
             * spearman : Spearman rank correlation
+        min_periods : int, optional
+            Minimum number of observations required per pair of columns
+            to have a valid result. Currently only available for Pearson
+            correlation.
+
+            .. versionadded:: 3.4.0
 
         Returns
         -------
@@ -1454,11 +1462,202 @@ class DataFrame(Frame, Generic[T]):
         There are behavior differences between pandas-on-Spark and pandas.
 
         * the `method` argument only accepts 'pearson', 'spearman'
-        * the data should not contain NaNs. pandas-on-Spark will return an error.
-        * pandas-on-Spark doesn't support the following argument(s).
+        * if the `method` is `spearman`, the data should not contain NaNs.
+        * if the `method` is `spearman`, `min_periods` argument is not supported.
+        """
+        if method not in ["pearson", "spearman", "kendall"]:
+            raise ValueError(f"Invalid method {method}")
+        if method == "kendall":
+            raise NotImplementedError("method doesn't support kendall for now")
+        if min_periods is not None and not isinstance(min_periods, int):
+            raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}")
+        if min_periods is not None and method == "spearman":
+            raise NotImplementedError("min_periods doesn't support spearman for now")
+
+        if method == "pearson":
+            min_periods = 1 if min_periods is None else min_periods
+            internal = self._internal.resolved_copy
+            numeric_labels = [
+                label
+                for label in internal.column_labels
+                if isinstance(internal.spark_type_for(label), (NumericType, BooleanType))
+            ]
+            numeric_scols: List[Column] = [
+                internal.spark_column_for(label).cast("double") for label in numeric_labels
+            ]
+            numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels]
+            num_scols = len(numeric_scols)
+
+            sdf = internal.spark_frame
+            tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__")
+            tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__")
+            tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__")
+            tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__")
+
+            # sample dataset
+            # +---+---+----+
+            # |  A|  B|   C|
+            # +---+---+----+
+            # |  1|  2| 3.0|
+            # |  4|  1|null|
+            # +---+---+----+
+
+            pair_scols: List[Column] = []
+            for i in range(0, num_scols):
+                for j in range(i, num_scols):
+                    pair_scols.append(
+                        F.struct(
+                            F.lit(i).alias(tmp_index_1_col_name),
+                            F.lit(j).alias(tmp_index_2_col_name),
+                            numeric_scols[i].alias(tmp_value_1_col_name),
+                            numeric_scols[j].alias(tmp_value_2_col_name),
+                        )
+                    )
+
+            # +-------------------+-------------------+-------------------+-------------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__|
+            # +-------------------+-------------------+-------------------+-------------------+
+            # |                  0|                  0|                1.0|                1.0|
+            # |                  0|                  1|                1.0|                2.0|
+            # |                  0|                  2|                1.0|                3.0|
+            # |                  1|                  1|                2.0|                2.0|
+            # |                  1|                  2|                2.0|                3.0|
+            # |                  2|                  2|                3.0|                3.0|
+            # |                  0|                  0|                4.0|                4.0|
+            # |                  0|                  1|                4.0|                1.0|
+            # |                  0|                  2|                4.0|               null|
+            # |                  1|                  1|                1.0|                1.0|
+            # |                  1|                  2|                1.0|               null|
+            # |                  2|                  2|               null|               null|
+            # +-------------------+-------------------+-------------------+-------------------+
+            tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__")
+            sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select(
+                F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name),
+                F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name),
+                F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}").alias(tmp_value_1_col_name),
+                F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}").alias(tmp_value_2_col_name),
+            )
+
+            # +-------------------+-------------------+------------------------+-----------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|__tmp_count_col__|
+            # +-------------------+-------------------+------------------------+-----------------+
+            # |                  2|                  2|                    null|                1|
+            # |                  1|                  2|                    null|                1|
+            # |                  1|                  1|                     1.0|                2|
+            # |                  0|                  0|                     1.0|                2|
+            # |                  0|                  1|                    -1.0|                2|
+            # |                  0|                  2|                    null|                1|
+            # +-------------------+-------------------+------------------------+-----------------+
+            tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__")
+            tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__")
+            sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg(
+                F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name),
+                F.count(
+                    F.when(
+                        F.col(tmp_value_1_col_name).isNotNull()
+                        & F.col(tmp_value_2_col_name).isNotNull(),
+                        1,
+                    )
+                ).alias(tmp_count_col_name),
+            )
+
+            # +-------------------+-------------------+------------------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|
+            # +-------------------+-------------------+------------------------+
+            # |                  2|                  2|                    null|
+            # |                  1|                  2|                    null|
+            # |                  2|                  1|                    null|
+            # |                  1|                  1|                     1.0|
+            # |                  0|                  0|                     1.0|
+            # |                  0|                  1|                    -1.0|
+            # |                  1|                  0|                    -1.0|
+            # |                  0|                  2|                    null|
+            # |                  2|                  0|                    null|
+            # +-------------------+-------------------+------------------------+
+            sdf = (
+                sdf.withColumn(
+                    tmp_corr_col_name,
+                    F.when(
+                        F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name)
+                    ).otherwise(F.lit(None)),
+                )
+                .withColumn(
+                    tmp_tuple_col_name,
+                    F.explode(
+                        F.when(
+                            F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name),
+                            F.lit([0]),
+                        ).otherwise(F.lit([0, 1]))
+                    ),
+                )
+                .select(
+                    F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name))
+                    .otherwise(F.col(tmp_index_2_col_name))
+                    .alias(tmp_index_1_col_name),
+                    F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name))
+                    .otherwise(F.col(tmp_index_1_col_name))
+                    .alias(tmp_index_2_col_name),
+                    F.col(tmp_corr_col_name),
+                )
+            )
+
+            # +-------------------+--------------------+
+            # |__tmp_index_1_col__|   __tmp_array_col__|
+            # +-------------------+--------------------+
+            # |                  0|[{0, 1.0}, {1, -1...|
+            # |                  1|[{0, -1.0}, {1, 1...|
+            # |                  2|[{0, null}, {1, n...|
+            # +-------------------+--------------------+
+            tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__")
+            sdf = (
+                sdf.groupby(tmp_index_1_col_name)
+                .agg(
+                    F.array_sort(
+                        F.collect_list(
+                            F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name))
+                        )
+                    ).alias(tmp_array_col_name)
+                )
+                .orderBy(tmp_index_1_col_name)
+            )
+
+            for i in range(0, num_scols):
+                sdf = sdf.withColumn(
+                    tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i)
+                ).withColumn(
+                    numeric_col_names[i],
+                    F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"),
+                )
+
+            index_col_names: List[str] = []
+            if internal.column_labels_level > 1:
+                for level in range(0, internal.column_labels_level):
+                    index_col_name = SPARK_INDEX_NAME_FORMAT(level)
+                    indices = [label[level] for label in numeric_labels]
+                    sdf = sdf.withColumn(
+                        index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name))
+                    )
+                    index_col_names.append(index_col_name)
+            else:
+                sdf = sdf.withColumn(
+                    SPARK_DEFAULT_INDEX_NAME,
+                    F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)),
+                )
+                index_col_names = [SPARK_DEFAULT_INDEX_NAME]
+
+            sdf = sdf.select(*index_col_names, *numeric_col_names)
+
+            return DataFrame(
+                InternalFrame(
+                    spark_frame=sdf,
+                    index_spark_columns=[
+                        scol_for(sdf, index_col_name) for index_col_name in index_col_names
+                    ],
+                    column_labels=numeric_labels,
+                    column_label_names=internal.column_label_names,
+                )
+            )
 
-          * `min_periods` argument is not supported
-        """
         return cast(DataFrame, ps.from_pandas(corr(self, method)))
 
     # TODO: add axis parameter and support more methods
diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py
index e8f5048033b..7e2ca96e60f 100644
--- a/python/pyspark/pandas/tests/test_stats.py
+++ b/python/pyspark/pandas/tests/test_stats.py
@@ -257,6 +257,40 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
         self.assert_eq(psdf.skew(), pdf.skew(), almost=True)
         self.assert_eq(psdf.kurt(), pdf.kurt(), almost=True)
 
+    def test_dataframe_corr(self):
+        # the existing 'test_corr' mixes df.corr and ser.corr tests; delete 'test_corr'
+        # once df.corr and ser.corr have separate tests
+        pdf = makeMissingDataframe(0.3, 42)
+        psdf = ps.from_pandas(pdf)
+
+        with self.assertRaisesRegex(ValueError, "Invalid method"):
+            psdf.corr("std")
+        with self.assertRaisesRegex(NotImplementedError, "kendall for now"):
+            psdf.corr("kendall")
+        with self.assertRaisesRegex(TypeError, "Invalid min_periods type"):
+            psdf.corr(min_periods="3")
+        with self.assertRaisesRegex(NotImplementedError, "spearman for now"):
+            psdf.corr(method="spearman", min_periods=3)
+
+        self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+        self.assert_eq(
+            (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
+        )
+
+        # multi-index columns
+        columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Z", "D")])
+        pdf.columns = columns
+        psdf.columns = columns
+
+        self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+        self.assert_eq(
+            (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
+        )
+
     def test_corr(self):
         # Disable arrow execution since corr() is using UDT internally which is not supported.
         with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org