You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/06/27 21:05:58 UTC
[spark] branch master updated: [SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 20a8fc87d67 [SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect
20a8fc87d67 is described below
commit 20a8fc87d67c842ac3386dc6ae0c53a9533900c2
Author: itholic <ha...@databricks.com>
AuthorDate: Tue Jun 27 14:05:42 2023 -0700
[SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect
### What changes were proposed in this pull request?
This PR proposes to add `LastNonNull` and `NullIndex` to SparkConnectPlanner to enable `Series.interpolate`.
### Why are the changes needed?
To increase pandas API coverage with Spark Connect.
### Does this PR introduce _any_ user-facing change?
Yes, `Series.interpolate` becomes available with Spark Connect as of this change.
### How was this patch tested?
By reusing the existing unit tests (UTs).
Closes #41670 from itholic/interpolate.
Authored-by: itholic <ha...@databricks.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 8 +++++++
python/pyspark/pandas/series.py | 9 ++++---
python/pyspark/pandas/spark/functions.py | 28 ++++++++++++++++++++++
.../tests/connect/test_parity_generic_functions.py | 4 +++-
python/pyspark/sql/utils.py | 14 ++++++++++-
5 files changed, 56 insertions(+), 7 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c19fc5fe90e..ff158990560 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1768,6 +1768,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val ignoreNA = extractBoolean(children(2), "ignoreNA")
Some(EWM(children(0), alpha, ignoreNA))
+ case "last_non_null" if fun.getArgumentsCount == 1 =>
+ val children = fun.getArgumentsList.asScala.map(transformExpression)
+ Some(LastNonNull(children(0)))
+
+ case "null_index" if fun.getArgumentsCount == 1 =>
+ val children = fun.getArgumentsList.asScala.map(transformExpression)
+ Some(NullIndex(children(0)))
+
// ML-specific functions
case "vector_to_array" if fun.getArgumentsCount == 2 =>
val expr = transformExpression(fun.getArguments(0))
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 0f1e814946a..95ca92e7878 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -53,7 +53,6 @@ from pandas.api.types import ( # type: ignore[attr-defined]
CategoricalDtype,
)
from pandas.tseries.frequencies import DateOffset
-from pyspark import SparkContext
from pyspark.sql import functions as F, Column as PySparkColumn, DataFrame as SparkDataFrame
from pyspark.sql.types import (
ArrayType,
@@ -70,7 +69,7 @@ from pyspark.sql.types import (
TimestampType,
)
from pyspark.sql.window import Window
-from pyspark.sql.utils import get_column_class
+from pyspark.sql.utils import get_column_class, get_window_class
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T
@@ -2257,10 +2256,10 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
return self._psdf.copy()._psser_for(self._column_label)
scol = self.spark.column
- sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
- last_non_null = PySparkColumn(sql_utils.lastNonNull(scol._jc))
- null_index = PySparkColumn(sql_utils.nullIndex(scol._jc))
+ last_non_null = SF.last_non_null(scol)
+ null_index = SF.null_index(scol)
+ Window = get_window_class()
window_forward = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
Window.unboundedPreceding, Window.currentRow
)
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 06d5692238d..44650fd4d20 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -157,3 +157,31 @@ def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
else:
sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+
+
+def last_non_null(col: Column) -> Column:
+ if is_remote():
+ from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+ return _invoke_function_over_columns( # type: ignore[return-value]
+ "last_non_null",
+ col, # type: ignore[arg-type]
+ )
+
+ else:
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+def null_index(col: Column) -> Column:
+ if is_remote():
+ from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+ return _invoke_function_over_columns( # type: ignore[return-value]
+ "null_index",
+ col, # type: ignore[arg-type]
+ )
+
+ else:
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
diff --git a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
index d2c05893ae2..1bf2650d874 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
@@ -24,7 +24,9 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
class GenericFunctionsParityTests(
GenericFunctionsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase
):
- @unittest.skip("TODO(SPARK-43631): Enable Series.interpolate with Spark Connect.")
+ @unittest.skip(
+ "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
+ )
def test_interpolate(self):
super().test_interpolate()
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 7ecfa65dcd1..608ed7e9ac9 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -46,6 +46,7 @@ if TYPE_CHECKING:
from pyspark.sql.session import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.column import Column
+ from pyspark.sql.window import Window
from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
has_numpy = False
@@ -188,7 +189,7 @@ def try_remote_window(f: FuncT) -> FuncT:
def wrapped(*args: Any, **kwargs: Any) -> Any:
if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
- from pyspark.sql.connect.window import Window
+ from pyspark.sql.connect.window import Window # type: ignore[misc]
return getattr(Window, f.__name__)(*args, **kwargs)
else:
@@ -282,3 +283,14 @@ def get_dataframe_class() -> Type["DataFrame"]:
return ConnectDataFrame # type: ignore[return-value]
else:
return PySparkDataFrame
+
+
+def get_window_class() -> Type["Window"]:
+ from pyspark.sql.window import Window as PySparkWindow
+
+ if is_remote():
+ from pyspark.sql.connect.window import Window as ConnectWindow
+
+ return ConnectWindow # type: ignore[return-value]
+ else:
+ return PySparkWindow
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org