Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/10 01:33:30 UTC

[spark] branch master updated: [SPARK-43616][PS][CONNECT] Enable `pyspark.pandas.spark.functions.mode` in Spark Connect

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0dc2bd45eee [SPARK-43616][PS][CONNECT] Enable `pyspark.pandas.spark.functions.mode` in Spark Connect
0dc2bd45eee is described below

commit 0dc2bd45eeed869e2cca58e35da59b31d5f87caa
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Sat Jun 10 10:33:15 2023 +0900

    [SPARK-43616][PS][CONNECT] Enable `pyspark.pandas.spark.functions.mode` in Spark Connect
    
    ### What changes were proposed in this pull request?
    Enable `pyspark.pandas.spark.functions.mode` in Spark Connect
    
    ### Why are the changes needed?
    For feature parity between Spark Connect and regular (non-Connect) PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `pyspark.pandas.spark.functions.mode` is now enabled in Spark Connect (see the sketch below).
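    
    A minimal sketch of the user-facing behavior (assumes a running Spark Connect session; the data is illustrative):
    
    ```python
    import pyspark.pandas as ps
    
    psser = ps.Series([1.0, 2.0, 2.0, None])
    
    # Both calls route through pyspark.pandas.spark.functions.mode, which
    # now dispatches to the "pandas_mode" Connect function when remote.
    psser.mode()              # dropna=True by default: nulls are ignored
    psser.mode(dropna=False)
    ```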
    
    ### How was this patch tested?
    Enabled the existing unit tests by removing the skip annotations from the parity test suites.
    
    Closes #41523 from zhengruifeng/ps_con_mode.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../spark/sql/connect/planner/SparkConnectPlanner.scala    |  8 ++++++--
 python/pyspark/pandas/spark/functions.py                   | 14 ++++++++++++--
 .../tests/connect/computation/test_parity_compute.py       |  4 +---
 .../pandas/tests/connect/series/test_parity_stat.py        |  6 ------
 4 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 27d55e14b66..5ad2b4b5cc7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -44,7 +44,6 @@ import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIden
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.PandasCovar
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
@@ -1638,7 +1637,12 @@ class SparkConnectPlanner(val session: SparkSession) extends Logging {
       case "pandas_covar" if fun.getArgumentsCount == 3 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val ddof = extractInteger(children(2), "ddof")
-        Some(PandasCovar(children(0), children(1), ddof).toAggregateExpression(false))
+        Some(aggregate.PandasCovar(children(0), children(1), ddof).toAggregateExpression(false))
+
+      case "pandas_mode" if fun.getArgumentsCount == 2 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        val ignoreNA = extractBoolean(children(1), "ignoreNA")
+        Some(aggregate.PandasMode(children(0), ignoreNA).toAggregateExpression(false))
 
       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index b669f5bab1e..f4c0261134e 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -56,8 +56,18 @@ def kurt(col: Column) -> Column:
 
 
 def mode(col: Column, dropna: bool) -> Column:
-    sc = SparkContext._active_spark_context
-    return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "pandas_mode",
+            col,  # type: ignore[arg-type]
+            lit(dropna),
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
 
 
 def covar(col1: Column, col2: Column, ddof: int) -> Column:
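
On the classic (non-Connect) path the helper can be exercised directly (a minimal sketch; `mode` is an internal API, shown here only for illustration):

```python
from pyspark.sql import SparkSession
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (2.0,), (None,)], ["x"])

# dropna=True ignores nulls; the aggregate yields the most frequent value(s).
df.select(SF.mode(df["x"], True).alias("mode")).show()
```
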
diff --git a/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py b/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py
index 11a60b31559..181dd309d3f 100644
--- a/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py
+++ b/python/pyspark/pandas/tests/connect/computation/test_parity_compute.py
@@ -33,9 +33,7 @@ class FrameParityComputeTests(FrameComputeMixin, PandasOnSparkTestUtils, ReusedC
     def test_diff(self):
         super().test_diff()
 
-    @unittest.skip(
-        "TODO(SPARK-43616): Enable pyspark.pandas.spark.functions.mode in Spark Connect."
-    )
+    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_mode(self):
         super().test_mode()
 
diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_stat.py b/python/pyspark/pandas/tests/connect/series/test_parity_stat.py
index a05d70e4c8f..ea2e57a6994 100644
--- a/python/pyspark/pandas/tests/connect/series/test_parity_stat.py
+++ b/python/pyspark/pandas/tests/connect/series/test_parity_stat.py
@@ -22,12 +22,6 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
 class SeriesParityStatTests(SeriesStatMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
-    @unittest.skip(
-        "TODO(SPARK-43616): Enable pyspark.pandas.spark.functions.mode in Spark Connect."
-    )
-    def test_mode(self):
-        super().test_mode()
-
     @unittest.skip(
         "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
     )

