You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/14 00:38:37 UTC
[spark] branch master updated: [SPARK-43684][SPARK-43685][SPARK-43686][SPARK-43691][CONNECT][PS] Fix `(NullOps|NumOps).(eq|ne)` for Spark Connect
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9c8592058cd [SPARK-43684][SPARK-43685][SPARK-43686][SPARK-43691][CONNECT][PS] Fix `(NullOps|NumOps).(eq|ne)` for Spark Connect
9c8592058cd is described below
commit 9c8592058cd1a5cc530fb30e6dc7c5c759ad528d
Author: itholic <ha...@databricks.com>
AuthorDate: Wed Jun 14 09:38:22 2023 +0900
[SPARK-43684][SPARK-43685][SPARK-43686][SPARK-43691][CONNECT][PS] Fix `(NullOps|NumOps).(eq|ne)` for Spark Connect
### What changes were proposed in this pull request?
This PR proposes to fix `NullOps.(eq|ne)` and `NumOps.(eq|ne)` for pandas API on Spark with Spark Connect.
This includes SPARK-43684, SPARK-43685, SPARK-43686, and SPARK-43691 at once, because they are all similar, related modifications in a single file.
This PR also introduces a new util function, `_is_extension_dtypes`, to check whether the given object has an extension dtype or not, and applies it to all related functions.
### Why are the changes needed?
The reason is that pandas API on Spark with Spark Connect operates differently from pandas as below:
**For `ne`:**
```python
>>> pser = pd.Series([1.0, 2.0, np.nan])
>>> psser = ps.from_pandas(pser)
>>> pser.ne(pser)
0 False
1 False
2 True
dtype: bool
>>> psser.ne(psser)
0 False
1 False
2 None
dtype: bool
```
We expect `True` for the non-equal case, but Spark Connect returns `None`. So we should cast `None` to `True` for `ne`.
**For `eq`:**
```python
>>> pser = pd.Series([1.0, 2.0, np.nan])
>>> psser = ps.from_pandas(pser)
>>> pser.eq(pser)
0 True
1 True
2 False
dtype: bool
>>> psser.eq(psser)
0 True
1 True
2 None
dtype: bool
```
We expect `False` for the non-equal case, but Spark Connect returns `None`. So we should cast `None` to `False` for `eq`.
### Does this PR introduce _any_ user-facing change?
Yes, `NullOps.eq`, `NullOps.ne`, `NumOps.eq`, `NumOps.ne` are now working as expected on Spark Connect.
### How was this patch tested?
Re-enabled the previously skipped UTs, and tested manually against vanilla PySpark.
Closes #41514 from itholic/SPARK-43684.
Authored-by: itholic <ha...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/pandas/data_type_ops/base.py | 9 +++++
python/pyspark/pandas/data_type_ops/binary_ops.py | 8 ++---
.../pandas/data_type_ops/categorical_ops.py | 6 ++--
.../pyspark/pandas/data_type_ops/datetime_ops.py | 8 ++---
python/pyspark/pandas/data_type_ops/null_ops.py | 39 ++++++++++------------
python/pyspark/pandas/data_type_ops/num_ops.py | 35 +++++++++----------
python/pyspark/pandas/data_type_ops/string_ops.py | 8 ++---
.../pyspark/pandas/data_type_ops/timedelta_ops.py | 12 +++----
.../connect/data_type_ops/test_parity_null_ops.py | 8 -----
.../connect/data_type_ops/test_parity_num_ops.py | 8 -----
.../pandas/tests/data_type_ops/test_null_ops.py | 1 +
python/pyspark/sql/utils.py | 16 ++++++---
12 files changed, 75 insertions(+), 83 deletions(-)
diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index d88eddee26b..18e792e292f 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -219,6 +219,15 @@ def _is_boolean_type(right: Any) -> bool:
)
+def _is_extension_dtypes(object: Any) -> bool:
+ """
+ Check whether the type of given object is extension dtype or not.
+ Extension dtype includes Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype, BooleanDtype,
+ StringDtype, Float32Dtype and Float64Dtype.
+ """
+ return isinstance(getattr(object, "dtype", None), extension_dtypes)
+
+
class DataTypeOps(object, metaclass=ABCMeta):
"""The base class for binary operations of pandas-on-Spark objects (of different data types)."""
diff --git a/python/pyspark/pandas/data_type_ops/binary_ops.py b/python/pyspark/pandas/data_type_ops/binary_ops.py
index ba31156178a..f528d3e9ae2 100644
--- a/python/pyspark/pandas/data_type_ops/binary_ops.py
+++ b/python/pyspark/pandas/data_type_ops/binary_ops.py
@@ -69,19 +69,19 @@ class BinaryOps(DataTypeOps):
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__lt__")(left, right)
+ return pyspark_column_op("__lt__", left, right)
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__le__")(left, right)
+ return pyspark_column_op("__le__", left, right)
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__ge__")(left, right)
+ return pyspark_column_op("__ge__", left, right)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__gt__")(left, right)
+ return pyspark_column_op("__gt__", left, right)
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index 66e181a6079..824666b5819 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -117,15 +117,15 @@ def _compare(
if hash(left.dtype) != hash(right.dtype):
raise TypeError("Categoricals can only be compared if 'categories' are the same.")
if cast(CategoricalDtype, left.dtype).ordered:
- return pyspark_column_op(func_name)(left, right)
+ return pyspark_column_op(func_name, left, right)
else:
- return pyspark_column_op(func_name)(_to_cat(left), _to_cat(right))
+ return pyspark_column_op(func_name, _to_cat(left), _to_cat(right))
elif not is_list_like(right):
categories = cast(CategoricalDtype, left.dtype).categories
if right not in categories:
raise TypeError("Cannot compare a Categorical with a scalar, which is not a category.")
right_code = categories.get_loc(right)
- return pyspark_column_op(func_name)(left, right_code)
+ return pyspark_column_op(func_name, left, right_code)
else:
raise TypeError("Cannot compare a Categorical with the given type.")
diff --git a/python/pyspark/pandas/data_type_ops/datetime_ops.py b/python/pyspark/pandas/data_type_ops/datetime_ops.py
index c5f4df96bde..ea9b994076b 100644
--- a/python/pyspark/pandas/data_type_ops/datetime_ops.py
+++ b/python/pyspark/pandas/data_type_ops/datetime_ops.py
@@ -111,19 +111,19 @@ class DatetimeOps(DataTypeOps):
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__lt__")(left, right)
+ return pyspark_column_op("__lt__", left, right)
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__le__")(left, right)
+ return pyspark_column_op("__le__", left, right)
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__ge__")(left, right)
+ return pyspark_column_op("__ge__", left, right)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__gt__")(left, right)
+ return pyspark_column_op("__gt__", left, right)
def prepare(self, col: pd.Series) -> pd.Series:
"""Prepare column when from_pandas."""
diff --git a/python/pyspark/pandas/data_type_ops/null_ops.py b/python/pyspark/pandas/data_type_ops/null_ops.py
index ab86f074b99..329a3790df6 100644
--- a/python/pyspark/pandas/data_type_ops/null_ops.py
+++ b/python/pyspark/pandas/data_type_ops/null_ops.py
@@ -17,7 +17,7 @@
from typing import Any, Union
-from pandas.api.types import CategoricalDtype
+from pandas.api.types import CategoricalDtype, is_list_like # type: ignore[attr-defined]
from pyspark.pandas._typing import Dtype, IndexOpsLike
from pyspark.pandas.data_type_ops.base import (
@@ -31,7 +31,8 @@ from pyspark.pandas.data_type_ops.base import (
from pyspark.pandas._typing import SeriesOrIndex
from pyspark.pandas.typedef import pandas_on_spark_type
from pyspark.sql.types import BooleanType, StringType
-from pyspark.sql.utils import pyspark_column_op, is_remote
+from pyspark.sql.utils import pyspark_column_op
+from pyspark.pandas.base import IndexOpsMixin
class NullOps(DataTypeOps):
@@ -43,37 +44,31 @@ class NullOps(DataTypeOps):
def pretty_name(self) -> str:
return "nulls"
+ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+ # We can directly use `super().eq` when given object is list, tuple, dict or set.
+ if not isinstance(right, IndexOpsMixin) and is_list_like(right):
+ return super().eq(left, right)
+ return pyspark_column_op("__eq__", left, right, fillna=False)
+
+ def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+ _sanitize_list_like(right)
+ return pyspark_column_op("__ne__", left, right, fillna=True)
+
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- result = pyspark_column_op("__lt__")(left, right)
- if is_remote():
- # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
- result = result.fillna(False)
- return result
+ return pyspark_column_op("__lt__", left, right, fillna=False)
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- result = pyspark_column_op("__le__")(left, right)
- if is_remote():
- # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
- result = result.fillna(False)
- return result
+ return pyspark_column_op("__le__", left, right, fillna=False)
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- result = pyspark_column_op("__ge__")(left, right)
- if is_remote():
- # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
- result = result.fillna(False)
- return result
+ return pyspark_column_op("__ge__", left, right, fillna=False)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- result = pyspark_column_op("__gt__")(left, right)
- if is_remote():
- # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
- result = result.fillna(False)
- return result
+ return pyspark_column_op("__gt__", left, right, fillna=False)
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py
index 9e7e2037a92..911228a5265 100644
--- a/python/pyspark/pandas/data_type_ops/num_ops.py
+++ b/python/pyspark/pandas/data_type_ops/num_ops.py
@@ -24,6 +24,7 @@ from pandas.api.types import ( # type: ignore[attr-defined]
is_bool_dtype,
is_integer_dtype,
CategoricalDtype,
+ is_list_like,
)
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
@@ -213,37 +214,31 @@ class NumericOps(DataTypeOps):
F.abs(operand.spark.column), field=operand._internal.data_fields[0]
)
+ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+ # We can directly use `super().eq` when given object is list, tuple, dict or set.
+ if not isinstance(right, IndexOpsMixin) and is_list_like(right):
+ return super().eq(left, right)
+ return pyspark_column_op("__eq__", left, right, fillna=False)
+
+ def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+ _sanitize_list_like(right)
+ return pyspark_column_op("__ne__", left, right, fillna=True)
+
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- result = pyspark_column_op("__lt__")(left, right)
- if is_remote():
- # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
- result = result.fillna(False)
- return result
+ return pyspark_column_op("__lt__", left, right, fillna=False)
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- result = pyspark_column_op("__le__")(left, right)
- if is_remote():
- # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
- result = result.fillna(False)
- return result
+ return pyspark_column_op("__le__", left, right, fillna=False)
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- result = pyspark_column_op("__ge__")(left, right)
- if is_remote():
- # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
- result = result.fillna(False)
- return result
+ return pyspark_column_op("__ge__", left, right, fillna=False)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- result = pyspark_column_op("__gt__")(left, right)
- if is_remote():
- # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
- result = result.fillna(False)
- return result
+ return pyspark_column_op("__gt__", left, right, fillna=False)
class IntegralOps(NumericOps):
diff --git a/python/pyspark/pandas/data_type_ops/string_ops.py b/python/pyspark/pandas/data_type_ops/string_ops.py
index e5818cb4635..1c282f20117 100644
--- a/python/pyspark/pandas/data_type_ops/string_ops.py
+++ b/python/pyspark/pandas/data_type_ops/string_ops.py
@@ -105,19 +105,19 @@ class StringOps(DataTypeOps):
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__lt__")(left, right)
+ return pyspark_column_op("__lt__", left, right)
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__le__")(left, right)
+ return pyspark_column_op("__le__", left, right)
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__ge__")(left, right)
+ return pyspark_column_op("__ge__", left, right)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__gt__")(left, right)
+ return pyspark_column_op("__gt__", left, right)
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/data_type_ops/timedelta_ops.py b/python/pyspark/pandas/data_type_ops/timedelta_ops.py
index 3e96ebbb13a..7a9da8511e6 100644
--- a/python/pyspark/pandas/data_type_ops/timedelta_ops.py
+++ b/python/pyspark/pandas/data_type_ops/timedelta_ops.py
@@ -72,7 +72,7 @@ class TimedeltaOps(DataTypeOps):
and isinstance(right.spark.data_type, DayTimeIntervalType)
or isinstance(right, timedelta)
):
- return pyspark_column_op("__sub__")(left, right)
+ return pyspark_column_op("__sub__", left, right)
else:
raise TypeError("Timedelta subtraction can only be applied to timedelta series.")
@@ -80,22 +80,22 @@ class TimedeltaOps(DataTypeOps):
_sanitize_list_like(right)
if isinstance(right, timedelta):
- return pyspark_column_op("__rsub__")(left, right)
+ return pyspark_column_op("__rsub__", left, right)
else:
raise TypeError("Timedelta subtraction can only be applied to timedelta series.")
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__lt__")(left, right)
+ return pyspark_column_op("__lt__", left, right)
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__le__")(left, right)
+ return pyspark_column_op("__le__", left, right)
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__ge__")(left, right)
+ return pyspark_column_op("__ge__", left, right)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- return pyspark_column_op("__gt__")(left, right)
+ return pyspark_column_op("__gt__", left, right)
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
index 1b53a064971..63b53c02fd7 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
@@ -29,14 +29,6 @@ class NullOpsParityTests(
def test_astype(self):
super().test_astype()
- @unittest.skip("TODO(SPARK-43684): Fix NullOps.eq to work with Spark Connect Column.")
- def test_eq(self):
- super().test_eq()
-
- @unittest.skip("TODO(SPARK-43685): Fix NullOps.ne to work with Spark Connect Column.")
- def test_ne(self):
- super().test_ne()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.data_type_ops.test_parity_null_ops import * # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
index b65873c6ab5..04aa24c4045 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
@@ -34,14 +34,6 @@ class NumOpsParityTests(
def test_astype(self):
super().test_astype()
- @unittest.skip("TODO(SPARK-43686): Enable NumOpsParityTests.test_eq.")
- def test_eq(self):
- super().test_eq()
-
- @unittest.skip("TODO(SPARK-43691): Enable NumOpsParityTests.test_ne.")
- def test_ne(self):
- super().test_ne()
-
if __name__ == "__main__":
from pyspark.pandas.tests.connect.data_type_ops.test_parity_num_ops import * # noqa: F401
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py
index 22ea26050bf..19a3e7c0735 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py
@@ -138,6 +138,7 @@ class NullOpsTestsMixin:
def test_eq(self):
pser, psser = self.pser, self.psser
self.assert_eq(pser == pser, psser == psser)
+ self.assert_eq(pser == [None, 1, None], psser == [None, 1, None])
def test_ne(self):
pser, psser = self.pser, self.psser
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 841ceb4fa1d..f5a5c88b8d3 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -16,7 +16,7 @@
#
import functools
import os
-from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar
+from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union
from py4j.java_collections import JavaArray
from py4j.java_gateway import (
@@ -45,7 +45,7 @@ from pyspark.find_spark_home import _find_spark_home
if TYPE_CHECKING:
from pyspark.sql.session import SparkSession
from pyspark.sql.dataframe import DataFrame
- from pyspark.pandas._typing import SeriesOrIndex
+ from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
has_numpy = False
try:
@@ -237,12 +237,15 @@ def try_remote_observation(f: FuncT) -> FuncT:
return cast(FuncT, wrapped)
-def pyspark_column_op(func_name: str) -> Callable[..., "SeriesOrIndex"]:
+def pyspark_column_op(
+ func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None
+) -> Union["SeriesOrIndex", None]:
"""
Wrapper function for column_op to get proper Column class.
"""
from pyspark.pandas.base import column_op
from pyspark.sql.column import Column as PySparkColumn
+ from pyspark.pandas.data_type_ops.base import _is_extension_dtypes
if is_remote():
from pyspark.sql.connect.column import Column as ConnectColumn
@@ -250,4 +253,9 @@ def pyspark_column_op(func_name: str) -> Callable[..., "SeriesOrIndex"]:
Column = ConnectColumn
else:
Column = PySparkColumn # type: ignore[assignment]
- return column_op(getattr(Column, func_name))
+ result = column_op(getattr(Column, func_name))(left, right)
+ # It works as expected on extension dtype, so we don't need to call `fillna` for this case.
+ if (fillna is not None) and (_is_extension_dtypes(left) or _is_extension_dtypes(right)):
+ fillna = None
+ # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
+ return result.fillna(fillna) if fillna is not None else result
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org