You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/06/14 07:53:51 UTC

[spark] branch master updated: [SPARK-43645][SPARK-43622][PS][CONNECT] Enable `pyspark.pandas.spark.functions.{var, stddev}` in Spark Connect

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ac7d7d835b0 [SPARK-43645][SPARK-43622][PS][CONNECT] Enable `pyspark.pandas.spark.functions.{var, stddev}` in Spark Connect
ac7d7d835b0 is described below

commit ac7d7d835b0cc8c4f5268d7588837336df4d5487
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Jun 14 15:52:55 2023 +0800

    [SPARK-43645][SPARK-43622][PS][CONNECT] Enable `pyspark.pandas.spark.functions.{var, stddev}` in Spark Connect
    
    ### What changes were proposed in this pull request?
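    Enable `pyspark.pandas.spark.functions.{var, stddev}` in Spark Connect. The Spark Connect planner gains `pandas_stddev` and `pandas_var` function handlers. When running remotely (`is_remote()`), the Python functions now build the expression via `_invoke_function_over_columns` instead of calling the JVM through `SparkContext`.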
    
    ### Why are the changes needed?
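    For feature parity between Spark Connect and classic PySpark. Previously these pandas-on-Spark helpers only worked through the JVM gateway (`SparkContext._active_spark_context`), which is not available in Spark Connect.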
    
    ### Does this PR introduce _any_ user-facing change?
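    Yes. `pyspark.pandas.spark.functions.var` and `pyspark.pandas.spark.functions.stddev` can now be used with Spark Connect.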
    
    ### How was this patch tested?
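    Re-enabled the Spark Connect parity unit tests that were skipped pending this change. Tests still blocked on other functions now carry updated skip reasons (`skew`: SPARK-43627, `kurt`: SPARK-43626).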
    
    Closes #41589 from zhengruifeng/ps_con_var.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 10 ++++++++
 python/pyspark/pandas/spark/functions.py           | 30 +++++++++++++++++++---
 .../tests/connect/frame/test_parity_constructor.py |  9 +------
 .../connect/groupby/test_parity_missing_data.py    |  6 -----
 .../tests/connect/groupby/test_parity_stat.py      | 13 ++--------
 .../tests/connect/test_parity_generic_functions.py |  4 ++-
 .../pandas/tests/connect/test_parity_stats.py      | 16 ++++--------
 7 files changed, 47 insertions(+), 41 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4d55639f876..e226e4be18f 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1639,6 +1639,16 @@ class SparkConnectPlanner(val session: SparkSession) extends Logging {
         val dropna = extractBoolean(children(1), "dropna")
         Some(aggregate.PandasProduct(children(0), dropna).toAggregateExpression(false))
 
+      case "pandas_stddev" if fun.getArgumentsCount == 2 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        val ddof = extractInteger(children(1), "ddof")
+        Some(aggregate.PandasStddev(children(0), ddof).toAggregateExpression(false))
+
+      case "pandas_var" if fun.getArgumentsCount == 2 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        val ddof = extractInteger(children(1), "ddof")
+        Some(aggregate.PandasVariance(children(0), ddof).toAggregateExpression(false))
+
       case "pandas_covar" if fun.getArgumentsCount == 3 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val ddof = extractInteger(children(2), "ddof")
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 2df51d525ce..739671821fe 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -47,13 +47,35 @@ def product(col: Column, dropna: bool) -> Column:
 
 
 def stddev(col: Column, ddof: int) -> Column:
-    sc = SparkContext._active_spark_context
-    return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_stddev",
+            col,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
 
 
 def var(col: Column, ddof: int) -> Column:
-    sc = SparkContext._active_spark_context
-    return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_var",
+            col,  # type: ignore[arg-type]
+            lit(ddof),
+        )
+
+    else:
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof))
 
 
 def skew(col: Column) -> Column:
diff --git a/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py b/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py
index e53de1b0720..19af61c0cef 100644
--- a/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py
+++ b/python/pyspark/pandas/tests/connect/frame/test_parity_constructor.py
@@ -16,7 +16,6 @@
 #
 import unittest
 
-from pyspark import pandas as ps
 from pyspark.pandas.tests.frame.test_constructor import FrameConstructorMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
@@ -25,13 +24,7 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 class FrameParityConstructorTests(
     FrameConstructorMixin, PandasOnSparkTestUtils, ReusedConnectTestCase
 ):
-    @property
-    def psdf(self):
-        return ps.from_pandas(self.pdf)
-
-    @unittest.skip("TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.var in Spark Connect.")
-    def test_dataframe(self):
-        super().test_dataframe()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py
index e6f91b0f17f..1ca101ef545 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_missing_data.py
@@ -42,12 +42,6 @@ class GroupbyParityMissingDataTests(
     def test_fillna(self):
         super().test_fillna()
 
-    @unittest.skip(
-        "TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.stddev in Spark Connect."
-    )
-    def test_dropna(self):
-        super().test_dropna()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.groupby.test_parity_missing_data import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
index c3fe251e616..75f1ed41d61 100644
--- a/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_stat.py
@@ -22,17 +22,8 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
 class GroupbyParityStatTests(GroupbyStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
-    @unittest.skip(
-        "TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.covar in Spark Connect."
-    )
-    def test_basic_stat_funcs(self):
-        super().test_basic_stat_funcs()
-
-    @unittest.skip(
-        "TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.stddev in Spark Connect."
-    )
-    def test_ddof(self):
-        super().test_ddof()
+
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
index 76a5c8d30ea..af165eee8f5 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
@@ -28,7 +28,9 @@ class GenericFunctionsParityTests(
     def test_interpolate(self):
         super().test_interpolate()
 
-    @unittest.skip("TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.std in Spark Connect.")
+    @unittest.skip(
+        "TODO(SPARK-43627): Enable pyspark.pandas.spark.functions.skew in Spark Connect."
+    )
     def test_stat_functions(self):
         super().test_stat_functions()
 
diff --git a/python/pyspark/pandas/tests/connect/test_parity_stats.py b/python/pyspark/pandas/tests/connect/test_parity_stats.py
index ae4597c3638..f72eec9dba4 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_stats.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_stats.py
@@ -35,26 +35,20 @@ class StatsParityTests(StatsTestsMixin, PandasOnSparkTestUtils, ReusedConnectTes
         super().test_skew_kurt_numerical_stability()
 
     @unittest.skip(
-        "TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.stddev in Spark Connect."
+        "TODO(SPARK-43627): Enable pyspark.pandas.spark.functions.skew in Spark Connect."
     )
     def test_stat_functions(self):
         super().test_stat_functions()
 
     @unittest.skip(
-        "TODO(SPARK-43645): Enable pyspark.pandas.spark.functions.stddev in Spark Connect."
+        "TODO(SPARK-43627): Enable pyspark.pandas.spark.functions.skew in Spark Connect."
     )
     def test_stat_functions_multiindex_column(self):
         super().test_stat_functions_multiindex_column()
 
-    @unittest.skip("TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.var in Spark Connect.")
-    def test_stats_on_boolean_dataframe(self):
-        super().test_stats_on_boolean_dataframe()
-
-    @unittest.skip("TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.var in Spark Connect.")
-    def test_stats_on_boolean_series(self):
-        super().test_stats_on_boolean_series()
-
-    @unittest.skip("TODO(SPARK-43622): Enable pyspark.pandas.spark.functions.var in Spark Connect.")
+    @unittest.skip(
+        "TODO(SPARK-43626): Enable pyspark.pandas.spark.functions.kurt in Spark Connect."
+    )
     def test_stats_on_non_numeric_columns_should_be_discarded_if_numeric_only_is_true(self):
         super().test_stats_on_non_numeric_columns_should_be_discarded_if_numeric_only_is_true()
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org