Posted to commits@spark.apache.org by ru...@apache.org on 2022/09/13 06:44:54 UTC
[spark] branch master updated: [SPARK-40399][PS] Make `pearson` correlation in `DataFrame.corr` support missing values and `min_periods`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ae08787f5c5 [SPARK-40399][PS] Make `pearson` correlation in `DataFrame.corr` support missing values and `min_periods`
ae08787f5c5 is described below
commit ae08787f5c50e485ef4432a0c2da8b3b7290d725
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Sep 13 14:44:18 2022 +0800
[SPARK-40399][PS] Make `pearson` correlation in `DataFrame.corr` support missing values and `min_periods`
### What changes were proposed in this pull request?
Refactor the `pearson` correlation in `DataFrame.corr` to:
1. support missing values;
2. add a `min_periods` parameter (a usage sketch follows the before/after examples below);
3. enable Arrow execution, since the implementation no longer depends on `VectorUDT`;
4. support lazy evaluation.
Before:
```
In [1]: import pyspark.pandas as ps
In [2]: df = ps.DataFrame([[1,2], [3,None]])
In [3]: df
   0    1
0  1  2.0
1  3  NaN
In [4]: df.corr()
22/09/09 16:53:18 ERROR Executor: Exception in task 9.0 in stage 5.0 (TID 24)
org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] Failed to execute user defined function (VectorAssembler$$Lambda$2660/0x0000000801215840: (struct<0_double_VectorAssembler_0915f96ec689:double,1:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
```
After:
```
In [1]: import pyspark.pandas as ps
In [2]: df = ps.DataFrame([[1,2], [3,None]])
In [3]: df.corr()
     0    1
0  1.0  NaN
1  NaN  NaN
In [4]: df.to_pandas().corr()
/Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:976: PandasAPIOnSparkAdviceWarning: `to_pandas` loads all data into the driver's memory. It should only be used if the resulting pandas DataFrame is expected to be small.
warnings.warn(message, PandasAPIOnSparkAdviceWarning)
Out[4]:
     0    1
0  1.0  NaN
1  NaN  NaN
```
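To illustrate the new parameter, here is a minimal sketch continuing the session above (illustrative, not taken from the PR; the outputs follow from the patch's masking rule that any column pair with fewer than `min_periods` complete observations yields NaN):
```
In [5]: df.corr(min_periods=2)
     0    1
0  1.0  NaN
1  NaN  NaN

In [6]: df.corr(min_periods=3)
     0    1
0  NaN  NaN
1  NaN  NaN
```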
### Why are the changes needed?
For API coverage, and to support the common case of data containing missing values.
### Does this PR introduce _any_ user-facing change?
Yes, this is an API change: `DataFrame.corr` now accepts a new `min_periods` parameter.
### How was this patch tested?
Added unit tests.
Closes #37845 from zhengruifeng/ps_df_corr_missing_value.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/pandas/frame.py | 209 +++++++++++++++++++++++++++++-
python/pyspark/pandas/tests/test_stats.py | 34 +++++
2 files changed, 238 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 3438d07896e..cf14a548266 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -1417,15 +1417,23 @@ class DataFrame(Frame, Generic[T]):
agg = aggregate
- def corr(self, method: str = "pearson") -> "DataFrame":
+ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "DataFrame":
"""
Compute pairwise correlation of columns, excluding NA/null values.
+ .. versionadded:: 3.3.0
+
Parameters
----------
method : {'pearson', 'spearman'}
* pearson : standard correlation coefficient
* spearman : Spearman rank correlation
+ min_periods : int, optional
+ Minimum number of observations required per pair of columns
+ to have a valid result. Currently only available for Pearson
+ correlation.
+
+ .. versionadded:: 3.4.0
Returns
-------
@@ -1454,11 +1462,202 @@ class DataFrame(Frame, Generic[T]):
There are behavior differences between pandas-on-Spark and pandas.
* the `method` argument only accepts 'pearson', 'spearman'
- * the data should not contain NaNs. pandas-on-Spark will return an error.
- * pandas-on-Spark doesn't support the following argument(s).
+ * if the `method` is `spearman`, the data should not contain NaNs.
+ * if the `method` is `spearman`, `min_periods` argument is not supported.
+ """
+ if method not in ["pearson", "spearman", "kendall"]:
+ raise ValueError(f"Invalid method {method}")
+ if method == "kendall":
+ raise NotImplementedError("method doesn't support kendall for now")
+ if min_periods is not None and not isinstance(min_periods, int):
+ raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}")
+ if min_periods is not None and method == "spearman":
+ raise NotImplementedError("min_periods doesn't support spearman for now")
+
+ if method == "pearson":
+ min_periods = 1 if min_periods is None else min_periods
+ internal = self._internal.resolved_copy
+ numeric_labels = [
+ label
+ for label in internal.column_labels
+ if isinstance(internal.spark_type_for(label), (NumericType, BooleanType))
+ ]
+ numeric_scols: List[Column] = [
+ internal.spark_column_for(label).cast("double") for label in numeric_labels
+ ]
+ numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels]
+ num_scols = len(numeric_scols)
+
+ sdf = internal.spark_frame
+ tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__")
+ tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__")
+ tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__")
+ tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__")
+
+ # simple dataset
+ # +---+---+----+
+ # | A| B| C|
+ # +---+---+----+
+ # | 1| 2| 3.0|
+ # | 4| 1|null|
+ # +---+---+----+
+
+ pair_scols: List[Column] = []
+ for i in range(0, num_scols):
+ for j in range(i, num_scols):
+ pair_scols.append(
+ F.struct(
+ F.lit(i).alias(tmp_index_1_col_name),
+ F.lit(j).alias(tmp_index_2_col_name),
+ numeric_scols[i].alias(tmp_value_1_col_name),
+ numeric_scols[j].alias(tmp_value_2_col_name),
+ )
+ )
+
+ # +-------------------+-------------------+-------------------+-------------------+
+ # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__|
+ # +-------------------+-------------------+-------------------+-------------------+
+ # | 0| 0| 1.0| 1.0|
+ # | 0| 1| 1.0| 2.0|
+ # | 0| 2| 1.0| 3.0|
+ # | 1| 1| 2.0| 2.0|
+ # | 1| 2| 2.0| 3.0|
+ # | 2| 2| 3.0| 3.0|
+ # | 0| 0| 4.0| 4.0|
+ # | 0| 1| 4.0| 1.0|
+ # | 0| 2| 4.0| null|
+ # | 1| 1| 1.0| 1.0|
+ # | 1| 2| 1.0| null|
+ # | 2| 2| null| null|
+ # +-------------------+-------------------+-------------------+-------------------+
+ tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__")
+ sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select(
+ F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name),
+ F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name),
+ F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}").alias(tmp_value_1_col_name),
+ F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}").alias(tmp_value_2_col_name),
+ )
+
+ # +-------------------+-------------------+------------------------+-----------------+
+ # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|__tmp_count_col__|
+ # +-------------------+-------------------+------------------------+-----------------+
+ # | 2| 2| null| 1|
+ # | 1| 2| null| 1|
+ # | 1| 1| 1.0| 2|
+ # | 0| 0| 1.0| 2|
+ # | 0| 1| -1.0| 2|
+ # | 0| 2| null| 1|
+ # +-------------------+-------------------+------------------------+-----------------+
+ tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__")
+ tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__")
+ sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg(
+ F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name),
+ F.count(
+ F.when(
+ F.col(tmp_value_1_col_name).isNotNull()
+ & F.col(tmp_value_2_col_name).isNotNull(),
+ 1,
+ )
+ ).alias(tmp_count_col_name),
+ )
+
+ # +-------------------+-------------------+------------------------+
+ # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|
+ # +-------------------+-------------------+------------------------+
+ # | 2| 2| null|
+ # | 1| 2| null|
+ # | 2| 1| null|
+ # | 1| 1| 1.0|
+ # | 0| 0| 1.0|
+ # | 0| 1| -1.0|
+ # | 1| 0| -1.0|
+ # | 0| 2| null|
+ # | 2| 0| null|
+ # +-------------------+-------------------+------------------------+
+ sdf = (
+ sdf.withColumn(
+ tmp_corr_col_name,
+ F.when(
+ F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name)
+ ).otherwise(F.lit(None)),
+ )
+ .withColumn(
+ tmp_tuple_col_name,
+ F.explode(
+ F.when(
+ F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name),
+ F.lit([0]),
+ ).otherwise(F.lit([0, 1]))
+ ),
+ )
+ .select(
+ F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name))
+ .otherwise(F.col(tmp_index_2_col_name))
+ .alias(tmp_index_1_col_name),
+ F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name))
+ .otherwise(F.col(tmp_index_1_col_name))
+ .alias(tmp_index_2_col_name),
+ F.col(tmp_corr_col_name),
+ )
+ )
+
+ # +-------------------+--------------------+
+ # |__tmp_index_1_col__| __tmp_array_col__|
+ # +-------------------+--------------------+
+ # | 0|[{0, 1.0}, {1, -1...|
+ # | 1|[{0, -1.0}, {1, 1...|
+ # | 2|[{0, null}, {1, n...|
+ # +-------------------+--------------------+
+ tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__")
+ sdf = (
+ sdf.groupby(tmp_index_1_col_name)
+ .agg(
+ F.array_sort(
+ F.collect_list(
+ F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name))
+ )
+ ).alias(tmp_array_col_name)
+ )
+ .orderBy(tmp_index_1_col_name)
+ )
+
+ for i in range(0, num_scols):
+ sdf = sdf.withColumn(
+ tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i)
+ ).withColumn(
+ numeric_col_names[i],
+ F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"),
+ )
+
+ index_col_names: List[str] = []
+ if internal.column_labels_level > 1:
+ for level in range(0, internal.column_labels_level):
+ index_col_name = SPARK_INDEX_NAME_FORMAT(level)
+ indices = [label[level] for label in numeric_labels]
+ sdf = sdf.withColumn(
+ index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name))
+ )
+ index_col_names.append(index_col_name)
+ else:
+ sdf = sdf.withColumn(
+ SPARK_DEFAULT_INDEX_NAME,
+ F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)),
+ )
+ index_col_names = [SPARK_DEFAULT_INDEX_NAME]
+
+ sdf = sdf.select(*index_col_names, *numeric_col_names)
+
+ return DataFrame(
+ InternalFrame(
+ spark_frame=sdf,
+ index_spark_columns=[
+ scol_for(sdf, index_col_name) for index_col_name in index_col_names
+ ],
+ column_labels=numeric_labels,
+ column_label_names=internal.column_label_names,
+ )
+ )
- * `min_periods` argument is not supported
- """
return cast(DataFrame, ps.from_pandas(corr(self, method)))
# TODO: add axis parameter and support more methods
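For readers skimming the diff, the following is a minimal, self-contained sketch of the pairwise-explode technique the new implementation builds on. It is not the patch itself: the toy column names (`A`, `B`, `C`) and the short helper names (`i`, `j`, `x`, `y`, `n`) are this sketch's own, and the real code additionally mirrors each (i, j) cell to (j, i) and pivots the long table back into a square correlation matrix:
```
# Sketch only: the pairwise-explode idea from the patch, applied to the
# "simple dataset" shown in the comments above. Assumes a local PySpark session.
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()

sdf = spark.createDataFrame([(1.0, 2.0, 3.0), (4.0, 1.0, None)], ["A", "B", "C"])
cols = sdf.columns
min_periods = 1

# One struct per upper-triangle column pair (i, j) per row ...
pair_scols = [
    F.struct(
        F.lit(i).alias("i"),
        F.lit(j).alias("j"),
        F.col(cols[i]).alias("x"),
        F.col(cols[j]).alias("y"),
    )
    for i in range(len(cols))
    for j in range(i, len(cols))
]

# ... exploded into long format: one row per (pair, input row).
long_sdf = sdf.select(F.explode(F.array(*pair_scols)).alias("p")).select("p.*")

# Per-pair Pearson correlation, plus the number of complete (non-null) rows;
# any pair with fewer than `min_periods` complete rows is masked to null.
result = (
    long_sdf.groupby("i", "j")
    .agg(
        F.corr("x", "y").alias("corr"),
        F.count(F.when(F.col("x").isNotNull() & F.col("y").isNotNull(), 1)).alias("n"),
    )
    .withColumn("corr", F.when(F.col("n") >= min_periods, F.col("corr")))
)
result.orderBy("i", "j").show()
```
Because everything is expressed with built-in Spark SQL functions rather than a `VectorUDT`-based UDF, the result stays lazy and Arrow execution can remain enabled, which is exactly what points 3 and 4 of the PR description refer to.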
diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py
index e8f5048033b..7e2ca96e60f 100644
--- a/python/pyspark/pandas/tests/test_stats.py
+++ b/python/pyspark/pandas/tests/test_stats.py
@@ -257,6 +257,40 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
self.assert_eq(psdf.skew(), pdf.skew(), almost=True)
self.assert_eq(psdf.kurt(), pdf.kurt(), almost=True)
+ def test_dataframe_corr(self):
+ # the existing 'test_corr' mixes df.corr and ser.corr; delete 'test_corr'
+ # once there are separate tests for df.corr and ser.corr
+ pdf = makeMissingDataframe(0.3, 42)
+ psdf = ps.from_pandas(pdf)
+
+ with self.assertRaisesRegex(ValueError, "Invalid method"):
+ psdf.corr("std")
+ with self.assertRaisesRegex(NotImplementedError, "kendall for now"):
+ psdf.corr("kendall")
+ with self.assertRaisesRegex(TypeError, "Invalid min_periods type"):
+ psdf.corr(min_periods="3")
+ with self.assertRaisesRegex(NotImplementedError, "spearman for now"):
+ psdf.corr(method="spearman", min_periods=3)
+
+ self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
+ self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
+ self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+ self.assert_eq(
+ (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
+ )
+
+ # multi-index columns
+ columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Z", "D")])
+ pdf.columns = columns
+ psdf.columns = columns
+
+ self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
+ self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
+ self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+ self.assert_eq(
+ (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
+ )
+
def test_corr(self):
# Disable arrow execution since corr() is using UDT internally which is not supported.
with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
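To run the new test locally, an invocation along these lines should work (hedged: it assumes the standard Spark development setup described on the project's developer-tools page):
```
./python/run-tests --testnames 'pyspark.pandas.tests.test_stats StatsTest.test_dataframe_corr'
```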
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org