You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/06/20 00:15:39 UTC
[spark] branch master updated: [SPARK-43624][PS][CONNECT] Add `EWM` to SparkConnectPlanner
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b4f4c372317 [SPARK-43624][PS][CONNECT] Add `EWM` to SparkConnectPlanner
b4f4c372317 is described below
commit b4f4c37231752d2eb6688b05e21410b3e823b427
Author: itholic <ha...@databricks.com>
AuthorDate: Tue Jun 20 08:15:24 2023 +0800
[SPARK-43624][PS][CONNECT] Add `EWM` to SparkConnectPlanner
### What changes were proposed in this pull request?
This PR proposes adding `EWM` support to SparkConnectPlanner.
### Why are the changes needed?
To increase pandas API on Spark coverage under Spark Connect.
### Does this PR introduce _any_ user-facing change?
No. `EWM` is added to SparkConnectPlanner, but there are still unresolved `AnalysisException` issues (SPARK-43611) that need to be addressed in follow-up work.
### How was this patch tested?
Manually checked that the plan is created with EWM properly.
Closes #41660 from itholic/EWM.
Authored-by: itholic <ha...@databricks.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 11 +++++++++++
.../pyspark/pandas/tests/connect/test_parity_ewm.py | 8 ++++++--
python/pyspark/pandas/window.py | 20 ++++++++++++++++++--
3 files changed, 35 insertions(+), 4 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b02b49d00dc..dc819fb4020 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1682,6 +1682,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val ignoreNA = extractBoolean(children(1), "ignoreNA")
Some(aggregate.PandasMode(children(0), ignoreNA).toAggregateExpression(false))
+ case "ewm" if fun.getArgumentsCount == 3 =>
+ val children = fun.getArgumentsList.asScala.map(transformExpression)
+ val alpha = extractDouble(children(1), "alpha")
+ val ignoreNA = extractBoolean(children(2), "ignoreNA")
+ Some(EWM(children(0), alpha, ignoreNA))
+
// ML-specific functions
case "vector_to_array" if fun.getArgumentsCount == 2 =>
val expr = transformExpression(fun.getArguments(0))
@@ -1742,6 +1748,11 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
case other => throw InvalidPlanInput(s"$field should be a literal boolean, but got $other")
}
+ /**
+ * Extracts a literal numeric argument of a function call as a `Double`.
+ *
+ * Integer literals are widened as well as double literals being accepted: clients build
+ * such arguments with `lit(...)`, and an integral value (e.g. an `alpha` of exactly 1)
+ * may arrive as an IntegerType literal rather than a DoubleType one.
+ *
+ * @param expr  the already-transformed argument expression
+ * @param field argument name reported in the error message
+ * @return the literal's value as a Double
+ * @throws InvalidPlanInput if `expr` is not a double or integer literal
+ */
+ private def extractDouble(expr: Expression, field: String): Double = expr match {
+ case Literal(double: Double, DoubleType) => double
+ case Literal(int: Int, IntegerType) => int.toDouble
+ case other => throw InvalidPlanInput(s"$field should be a literal double, but got $other")
+ }
+
private def extractInteger(expr: Expression, field: String): Int = expr match {
case Literal(int: Int, IntegerType) => int
case other => throw InvalidPlanInput(s"$field should be a literal integer, but got $other")
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py b/python/pyspark/pandas/tests/connect/test_parity_ewm.py
index 0e13306fd79..e079f847296 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_ewm.py
@@ -22,11 +22,15 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase, TestUtils):
- @unittest.skip("TODO(SPARK-43624): Enable ExponentialMovingLike.mean with Spark Connect.")
+ @unittest.skip(
+ "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
+ )
def test_ewm_mean(self):
super().test_ewm_mean()
- @unittest.skip("TODO(SPARK-43624): Enable ExponentialMovingLike.mean with Spark Connect.")
+ @unittest.skip(
+ "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
+ )
def test_groupby_ewm_func(self):
super().test_groupby_ewm_func()
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index 316a4af92dd..8d09dd132ca 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -44,6 +44,7 @@ from pyspark.sql.types import (
DoubleType,
)
from pyspark.sql.window import WindowSpec
+from pyspark.sql.utils import is_remote
class RollingAndExpanding(Generic[FrameLike], metaclass=ABCMeta):
@@ -2448,11 +2449,26 @@ class ExponentialMovingLike(Generic[FrameLike], metaclass=ABCMeta):
unified_alpha = self._compute_unified_alpha()
def mean(scol: Column) -> Column:
+ # Per-row exponentially weighted mean of `scol`, null until enough non-null values exist.
- sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
+ # Under Spark Connect (is_remote()), emit an unresolved "ewm" function call with
+ # arguments (column, alpha, ignore_na); SparkConnectPlanner maps it to the EWM expression.
+ if is_remote():
+ from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+ col_ewm = _invoke_function_over_columns(
+ "ewm",
+ scol, # type: ignore[arg-type]
+ lit(unified_alpha),
+ lit(self._ignore_na),
+ )
+ else:
+ # Classic mode: build the EWM expression directly via PythonSQLUtils.ewm on the JVM.
+ sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
+ col_ewm = Column(
+ sql_utils.ewm(
+ # NOTE(review): the ignore below annotates a call argument, not an assignment;
+ # confirm it is still needed.
+ scol._jc, unified_alpha, self._ignore_na # type: ignore[assignment]
+ )
+ )
+ # Return the EWM value over self._window only when the count of non-null `scol` values
+ # over self._unbounded_window reaches self._min_periods; otherwise return null.
return F.when(
F.count(F.when(~scol.isNull(), 1).otherwise(None)).over(self._unbounded_window)
>= self._min_periods,
- Column(sql_utils.ewm(scol._jc, unified_alpha, self._ignore_na)).over(self._window),
+ col_ewm.over(self._window), # type: ignore[arg-type]
).otherwise(F.lit(None))
return self._apply_as_series_or_frame(mean)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org