You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/06/27 21:05:58 UTC

[spark] branch master updated: [SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 20a8fc87d67 [SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect
20a8fc87d67 is described below

commit 20a8fc87d67c842ac3386dc6ae0c53a9533900c2
Author: itholic <ha...@databricks.com>
AuthorDate: Tue Jun 27 14:05:42 2023 -0700

    [SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add `LastNonNull` and `NullIndex` to SparkConnectPlanner to enable `Series.interpolate`.
    
    ### Why are the changes needed?
    
    To increase pandas API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, `Series.interpolate` will be available from this fix.
    
    ### How was this patch tested?
    
    Reusing the existing UT.
    
    Closes #41670 from itholic/interpolate.
    
    Authored-by: itholic <ha...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  8 +++++++
 python/pyspark/pandas/series.py                    |  9 ++++---
 python/pyspark/pandas/spark/functions.py           | 28 ++++++++++++++++++++++
 .../tests/connect/test_parity_generic_functions.py |  4 +++-
 python/pyspark/sql/utils.py                        | 14 ++++++++++-
 5 files changed, 56 insertions(+), 7 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c19fc5fe90e..ff158990560 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1768,6 +1768,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         val ignoreNA = extractBoolean(children(2), "ignoreNA")
         Some(EWM(children(0), alpha, ignoreNA))
 
+      case "last_non_null" if fun.getArgumentsCount == 1 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(LastNonNull(children(0)))
+
+      case "null_index" if fun.getArgumentsCount == 1 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(NullIndex(children(0)))
+
       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
         val expr = transformExpression(fun.getArguments(0))
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 0f1e814946a..95ca92e7878 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -53,7 +53,6 @@ from pandas.api.types import (  # type: ignore[attr-defined]
     CategoricalDtype,
 )
 from pandas.tseries.frequencies import DateOffset
-from pyspark import SparkContext
 from pyspark.sql import functions as F, Column as PySparkColumn, DataFrame as SparkDataFrame
 from pyspark.sql.types import (
     ArrayType,
@@ -70,7 +69,7 @@ from pyspark.sql.types import (
     TimestampType,
 )
 from pyspark.sql.window import Window
-from pyspark.sql.utils import get_column_class
+from pyspark.sql.utils import get_column_class, get_window_class
 
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
 from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T
@@ -2257,10 +2256,10 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
             return self._psdf.copy()._psser_for(self._column_label)
 
         scol = self.spark.column
-        sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
-        last_non_null = PySparkColumn(sql_utils.lastNonNull(scol._jc))
-        null_index = PySparkColumn(sql_utils.nullIndex(scol._jc))
+        last_non_null = SF.last_non_null(scol)
+        null_index = SF.null_index(scol)
 
+        Window = get_window_class()
         window_forward = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
             Window.unboundedPreceding, Window.currentRow
         )
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 06d5692238d..44650fd4d20 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -157,3 +157,31 @@ def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
     else:
         sc = SparkContext._active_spark_context
         return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+
+
+def last_non_null(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "last_non_null",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+def null_index(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "null_index",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
diff --git a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
index d2c05893ae2..1bf2650d874 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
@@ -24,7 +24,9 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
 class GenericFunctionsParityTests(
     GenericFunctionsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43631): Enable Series.interpolate with Spark Connect.")
+    @unittest.skip(
+        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
+    )
     def test_interpolate(self):
         super().test_interpolate()
 
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 7ecfa65dcd1..608ed7e9ac9 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -46,6 +46,7 @@ if TYPE_CHECKING:
     from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.column import Column
+    from pyspark.sql.window import Window
     from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
 
 has_numpy = False
@@ -188,7 +189,7 @@ def try_remote_window(f: FuncT) -> FuncT:
     def wrapped(*args: Any, **kwargs: Any) -> Any:
 
         if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
-            from pyspark.sql.connect.window import Window
+            from pyspark.sql.connect.window import Window  # type: ignore[misc]
 
             return getattr(Window, f.__name__)(*args, **kwargs)
         else:
@@ -282,3 +283,14 @@ def get_dataframe_class() -> Type["DataFrame"]:
         return ConnectDataFrame  # type: ignore[return-value]
     else:
         return PySparkDataFrame
+
+
+def get_window_class() -> Type["Window"]:
+    from pyspark.sql.window import Window as PySparkWindow
+
+    if is_remote():
+        from pyspark.sql.connect.window import Window as ConnectWindow
+
+        return ConnectWindow  # type: ignore[return-value]
+    else:
+        return PySparkWindow


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org