You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/06/20 00:15:39 UTC

[spark] branch master updated: [SPARK-43624][PS][CONNECT] Add `EWM` to SparkConnectPlanner

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b4f4c372317 [SPARK-43624][PS][CONNECT] Add `EWM` to SparkConnectPlanner
b4f4c372317 is described below

commit b4f4c37231752d2eb6688b05e21410b3e823b427
Author: itholic <ha...@databricks.com>
AuthorDate: Tue Jun 20 08:15:24 2023 +0800

    [SPARK-43624][PS][CONNECT] Add `EWM` to SparkConnectPlanner
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add `EWM` for SparkConnectPlanner.
    
    ### Why are the changes needed?
    
    To increase pandas API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, we added `EWM` to SparkConnectPlanner, but there are still unresolved `AnalysisException` issues (SPARK-43611) that need to be addressed in follow-up work.
    
    ### How was this patch tested?
    
    Manually checked that the plan was created properly with EWM.
    
    Closes #41660 from itholic/EWM.
    
    Authored-by: itholic <ha...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala    | 11 +++++++++++
 .../pyspark/pandas/tests/connect/test_parity_ewm.py  |  8 ++++++--
 python/pyspark/pandas/window.py                      | 20 ++++++++++++++++++--
 3 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b02b49d00dc..dc819fb4020 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1682,6 +1682,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         val ignoreNA = extractBoolean(children(1), "ignoreNA")
         Some(aggregate.PandasMode(children(0), ignoreNA).toAggregateExpression(false))
 
+      case "ewm" if fun.getArgumentsCount == 3 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        val alpha = extractDouble(children(1), "alpha")
+        val ignoreNA = extractBoolean(children(2), "ignoreNA")
+        Some(EWM(children(0), alpha, ignoreNA))
+
       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
         val expr = transformExpression(fun.getArguments(0))
@@ -1742,6 +1748,11 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     case other => throw InvalidPlanInput(s"$field should be a literal boolean, but got $other")
   }
 
+  private def extractDouble(expr: Expression, field: String): Double = expr match {
+    case Literal(double: Double, DoubleType) => double
+    case other => throw InvalidPlanInput(s"$field should be a literal double, but got $other")
+  }
+
   private def extractInteger(expr: Expression, field: String): Int = expr match {
     case Literal(int: Int, IntegerType) => int
     case other => throw InvalidPlanInput(s"$field should be a literal integer, but got $other")
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py b/python/pyspark/pandas/tests/connect/test_parity_ewm.py
index 0e13306fd79..e079f847296 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_ewm.py
@@ -22,11 +22,15 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
 
 
 class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase, TestUtils):
-    @unittest.skip("TODO(SPARK-43624): Enable ExponentialMovingLike.mean with Spark Connect.")
+    @unittest.skip(
+        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
+    )
     def test_ewm_mean(self):
         super().test_ewm_mean()
 
-    @unittest.skip("TODO(SPARK-43624): Enable ExponentialMovingLike.mean with Spark Connect.")
+    @unittest.skip(
+        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
+    )
     def test_groupby_ewm_func(self):
         super().test_groupby_ewm_func()
 
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index 316a4af92dd..8d09dd132ca 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -44,6 +44,7 @@ from pyspark.sql.types import (
     DoubleType,
 )
 from pyspark.sql.window import WindowSpec
+from pyspark.sql.utils import is_remote
 
 
 class RollingAndExpanding(Generic[FrameLike], metaclass=ABCMeta):
@@ -2448,11 +2449,26 @@ class ExponentialMovingLike(Generic[FrameLike], metaclass=ABCMeta):
         unified_alpha = self._compute_unified_alpha()
 
         def mean(scol: Column) -> Column:
-            sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
+            if is_remote():
+                from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+                col_ewm = _invoke_function_over_columns(
+                    "ewm",
+                    scol,  # type: ignore[arg-type]
+                    lit(unified_alpha),
+                    lit(self._ignore_na),
+                )
+            else:
+                sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
+                col_ewm = Column(
+                    sql_utils.ewm(
+                        scol._jc, unified_alpha, self._ignore_na  # type: ignore[assignment]
+                    )
+                )
             return F.when(
                 F.count(F.when(~scol.isNull(), 1).otherwise(None)).over(self._unbounded_window)
                 >= self._min_periods,
-                Column(sql_utils.ewm(scol._jc, unified_alpha, self._ignore_na)).over(self._window),
+                col_ewm.over(self._window),  # type: ignore[arg-type]
             ).otherwise(F.lit(None))
 
         return self._apply_as_series_or_frame(mean)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org