You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/06/14 07:53:51 UTC
[spark] branch master updated: [SPARK-43645][SPARK-43622][PS][CONNECT] Enable `pyspark.pandas.spark.functions.{var, stddev}` in Spark Connect
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ac7d7d835b0 [SPARK-43645][SPARK-43622][PS][CONNECT] Enable `pyspark.pandas.spark.functions.{var, stddev}` in Spark Connect
ac7d7d835b0 is described below
commit ac7d7d835b0cc8c4f5268d7588837336df4d5487
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Jun 14 15:52:55 2023 +0800
[SPARK-43645][SPARK-43622][PS][CONNECT] Enable `pyspark.pandas.spark.functions.{var, stddev}` in Spark Connect
### What changes were proposed in this pull request?
Enable `pyspark.pandas.spark.functions.{var, stddev}` in Spark Connect
### Why are the changes needed?
For feature parity between Spark Connect and regular (JVM-backed) PySpark.
### Does this PR introduce _any_ user-facing change?
Yes: `pyspark.pandas.spark.functions.{var, stddev}` are now supported in Spark Connect.
### How was this patch tested?
Enabled the previously skipped unit tests (UTs) that depend on these functions.
Closes #41589 from zhengruifeng/ps_con_var.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 10 ++++++++
python/pyspark/pandas/spark/functions.py | 30 +++++++++++++++++++---
.../tests/connect/frame/test_parity_constructor.py | 9 +------
.../connect/groupby/test_parity_missing_data.py | 6 -----
.../tests/connect/groupby/test_parity_stat.py | 13 ++--------
.../tests/connect/test_parity_generic_functions.py | 4 ++-
.../pandas/tests/connect/test_parity_stats.py | 16 ++++--------
7 files changed, 47 insertions(+), 41 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4d55639f876..e226e4be18f 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1639,6 +1639,16 @@ class SparkConnectPlanner(val session: SparkSession) extends Logging {
val dropna = extractBoolean(children(1), "dropna")
Some(aggregate.PandasProduct(children(0), dropna).toAggregateExpression(false))
+ case "pandas_stddev" if fun.getArgumentsCount == 2 =>
+ val children = fun.getArgumentsList.asScala.map(transformExpression)
+ val ddof = extractInteger(children(1), "ddof")
+ Some(aggregate.PandasStddev(children(0), ddof).toAggregateExpression(false))
+
+ case "pandas_var" if fun.getArgumentsCount == 2 =>
+ val children = fun.getArgumentsList.asScala.map(transformExpression)
+ val ddof = extractInteger(children(1), "ddof")
+ Some(aggregate.PandasVariance(children(0), ddof).toAggregateExpression(false))
+
case "pandas_covar" if fun.getArgumentsCount == 3 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val ddof = extractInteger(children(2), "ddof")
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 2df51d525ce..739671821fe 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -47,13 +47,35 @@ def product(col: Column, dropna: bool) -> Column:
def stddev(col: Column, ddof: int) -> Column:
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+ if is_remote():
+ from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+ return _invoke_function_over_columns( # type: ignore[return-value]
+ "pandas_stddev",
+ col, # type: ignore[arg-type]
+ lit(ddof),
+ )
+
+ else:
+
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
def var(col: Column, ddof: int) -> Column:
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+ if is_remote():
+ from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+ return _invoke_function_over_columns( # type: ignore[return-value]
+ "pandas_var",
+ col, # type: ignore[arg-type]
+ lit(ddof),
+ )
+
+ else:
+
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
def skew(col: Column) -> Column:
diff --git a/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py b/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py
index e53de1b0720..19af61c0cef 100644
--- a/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py
+++ b/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py
@@ -16,7 +16,6 @@
#
import unittest
-from pyspark import pandas as ps
from pyspark.pandas.tests.frame.test_constructor import FrameConstructorMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
@@ -25,13 +24,7 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils
class FrameParityConstructorTests(
FrameConstructorMixin, PandasOnSparkTestUtils, ReusedConnectTestCase
):
- @property
- def psdf(self):
- return ps.from_pandas(self.pdf)
-
- @unittest.skip("TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.var in Spark Connect.")
- def test_dataframe(self):
- super().test_dataframe()
+ pass
if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py
index e6f91b0f17f..1ca101ef545 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py
@@ -42,12 +42,6 @@ class GroupbyParityMissingDataTests(
def test_fillna(self):
super().test_fillna()
- @unittest.skip(
- "TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.stddev in Spark Connect."
- )
- def test_dropna(self):
- super().test_dropna()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.groupby.test_parity_missing_data import * # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
index c3fe251e616..75f1ed41d61 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
@@ -22,17 +22,8 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils
class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
- @unittest.skip(
- "TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.covar in Spark Connect."
- )
- def test_basic_stat_funcs(self):
- super().test_basic_stat_funcs()
-
- @unittest.skip(
- "TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.stddev in Spark Connect."
- )
- def test_ddof(self):
- super().test_ddof()
+
+ pass
if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
index 76a5c8d30ea..af165eee8f5 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
@@ -28,7 +28,9 @@ class GenericFunctionsParityTests(
def test_interpolate(self):
super().test_interpolate()
- @unittest.skip("TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.std in Spark Connect.")
+ @unittest.skip(
+ "TODO(SPARK-43627): Enable pyspark.pandas.spark.functions.skew in Spark Connect."
+ )
def test_stat_functions(self):
super().test_stat_functions()
diff --git a/python/pyspark/pandas/tests/connect/test_parity_stats.py b/python/pyspark/pandas/tests/connect/test_parity_stats.py
index ae4597c3638..f72eec9dba4 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_stats.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_stats.py
@@ -35,26 +35,20 @@ class StatsParityTests(StatsTestsMixin, PandasOnSparkTestUtils, ReusedConnectTes
super().test_skew_kurt_numerical_stability()
@unittest.skip(
- "TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.stddev in Spark Connect."
+ "TODO(SPARK-43627): Enable pyspark.pandas.spark.functions.skew in Spark Connect."
)
def test_stat_functions(self):
super().test_stat_functions()
@unittest.skip(
- "TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.stddev in Spark Connect."
+ "TODO(SPARK-43627): Enable pyspark.pandas.spark.functions.skew in Spark Connect."
)
def test_stat_functions_multiindex_column(self):
super().test_stat_functions_multiindex_column()
- @unittest.skip("TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.var in Spark Connect.")
- def test_stats_on_boolean_dataframe(self):
- super().test_stats_on_boolean_dataframe()
-
- @unittest.skip("TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.var in Spark Connect.")
- def test_stats_on_boolean_series(self):
- super().test_stats_on_boolean_series()
-
- @unittest.skip("TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.var in Spark Connect.")
+ @unittest.skip(
+ "TODO(SPARK-43626): Enable pyspark.pandas.spark.functions.kurt in Spark Connect."
+ )
def test_stats_on_non_numeric_columns_should_be_discarded_if_numeric_only_is_true(self):
super().test_stats_on_non_numeric_columns_should_be_discarded_if_numeric_only_is_true()
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org