You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ue...@apache.org on 2021/08/18 18:39:33 UTC
[spark] branch master updated: [SPARK-36368][PYTHON] Fix
CategoricalOps.astype to follow pandas 1.3
This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f2e593b [SPARK-36368][PYTHON] Fix CategoricalOps.astype to follow pandas 1.3
f2e593b is described below
commit f2e593bcf1a1aa8dde9f73b77e4863ceed5a7e28
Author: itholic <ha...@databricks.com>
AuthorDate: Wed Aug 18 11:38:59 2021 -0700
[SPARK-36368][PYTHON] Fix CategoricalOps.astype to follow pandas 1.3
### What changes were proposed in this pull request?
This PR proposes to fix the behavior of `astype` for `CategoricalDtype` to follow pandas 1.3.
**Before:**
```python
>>> pcat
0 a
1 b
2 c
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> pcat.astype(CategoricalDtype(["b", "c", "a"]))
0 a
1 b
2 c
dtype: category
Categories (3, object): ['b', 'c', 'a']
```
**After:**
```python
>>> pcat
0 a
1 b
2 c
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> pcat.astype(CategoricalDtype(["b", "c", "a"]))
0 a
1 b
2 c
dtype: category
Categories (3, object): ['a', 'b', 'c'] # CategoricalDtype is not updated if dtype is the same
```
`CategoricalDtype` is treated as the same `dtype` if the sets of category values are the same.
```python
>>> pcat1 = pser.astype(CategoricalDtype(["b", "c", "a"]))
>>> pcat2 = pser.astype(CategoricalDtype(["a", "b", "c"]))
>>> pcat1.dtype == pcat2.dtype
True
```
### Why are the changes needed?
We should follow the latest pandas as much as possible.
### Does this PR introduce _any_ user-facing change?
Yes, the behavior is changed as shown in the example in the PR description.
### How was this patch tested?
Unittest
Closes #33757 from itholic/SPARK-36368.
Authored-by: itholic <ha...@databricks.com>
Signed-off-by: Takuya UESHIN <ue...@databricks.com>
---
python/pyspark/pandas/categorical.py | 3 ++-
python/pyspark/pandas/data_type_ops/categorical_ops.py | 4 +++-
.../pandas/tests/data_type_ops/test_categorical_ops.py | 6 ++----
python/pyspark/pandas/tests/indexes/test_category.py | 16 +++++++---------
python/pyspark/pandas/tests/test_categorical.py | 16 +++++++---------
5 files changed, 21 insertions(+), 24 deletions(-)
diff --git a/python/pyspark/pandas/categorical.py b/python/pyspark/pandas/categorical.py
index 77a3cee..fa11228 100644
--- a/python/pyspark/pandas/categorical.py
+++ b/python/pyspark/pandas/categorical.py
@@ -22,6 +22,7 @@ from pandas.api.types import CategoricalDtype, is_dict_like, is_list_like
from pyspark.pandas.internal import InternalField
from pyspark.pandas.spark import functions as SF
+from pyspark.pandas.data_type_ops.categorical_ops import _to_cat
from pyspark.sql import functions as F
from pyspark.sql.types import StructField
@@ -735,7 +736,7 @@ class CategoricalAccessor(object):
return self._data.copy()
else:
dtype = CategoricalDtype(categories=new_categories, ordered=ordered)
- psser = self._data.astype(dtype)
+ psser = _to_cat(self._data).astype(dtype)
if inplace:
internal = self._data._psdf._internal.with_new_spark_column(
diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index b524cdd..c1be683 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -57,7 +57,9 @@ class CategoricalOps(DataTypeOps):
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, _ = pandas_on_spark_type(dtype)
- if isinstance(dtype, CategoricalDtype) and cast(CategoricalDtype, dtype).categories is None:
+ if isinstance(dtype, CategoricalDtype) and (
+ (dtype.categories is None) or (index_ops.dtype == dtype)
+ ):
return index_ops.copy()
return _to_cat(index_ops).astype(dtype)
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
index 11871ea..5e79eb3 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
@@ -192,13 +192,11 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser.astype("category"), psser.astype("category"))
cat_type = CategoricalDtype(categories=[3, 1, 2])
+ # CategoricalDtype is not updated if the dtype is same from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
- # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
- pass
- elif LooseVersion(pd.__version__) >= LooseVersion("1.2"):
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
else:
- self.assert_eq(pd.Series(data).astype(cat_type), psser.astype(cat_type))
+ self.assert_eq(psser.astype(cat_type), pser)
def test_neg(self):
self.assertRaises(TypeError, lambda: -self.psser)
diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py
index 6520363..69d4667 100644
--- a/python/pyspark/pandas/tests/indexes/test_category.py
+++ b/python/pyspark/pandas/tests/indexes/test_category.py
@@ -172,25 +172,23 @@ class CategoricalIndexTest(PandasOnSparkTestCase, TestUtils):
)
pcidx = pidx.astype(CategoricalDtype(["c", "a", "b"]))
- kcidx = psidx.astype(CategoricalDtype(["c", "a", "b"]))
+ pscidx = psidx.astype(CategoricalDtype(["c", "a", "b"]))
- self.assert_eq(kcidx.astype("category"), pcidx.astype("category"))
+ self.assert_eq(pscidx.astype("category"), pcidx.astype("category"))
+ # CategoricalDtype is not updated if the dtype is same from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
- # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
- pass
- elif LooseVersion(pd.__version__) >= LooseVersion("1.2"):
self.assert_eq(
- kcidx.astype(CategoricalDtype(["b", "c", "a"])),
+ pscidx.astype(CategoricalDtype(["b", "c", "a"])),
pcidx.astype(CategoricalDtype(["b", "c", "a"])),
)
else:
self.assert_eq(
- kcidx.astype(CategoricalDtype(["b", "c", "a"])),
- pidx.astype(CategoricalDtype(["b", "c", "a"])),
+ pscidx.astype(CategoricalDtype(["b", "c", "a"])),
+ pcidx,
)
- self.assert_eq(kcidx.astype(str), pcidx.astype(str))
+ self.assert_eq(pscidx.astype(str), pcidx.astype(str))
def test_factorize(self):
pidx = pd.CategoricalIndex([1, 2, 3, None])
diff --git a/python/pyspark/pandas/tests/test_categorical.py b/python/pyspark/pandas/tests/test_categorical.py
index 1335d59..1fb0d58 100644
--- a/python/pyspark/pandas/tests/test_categorical.py
+++ b/python/pyspark/pandas/tests/test_categorical.py
@@ -239,25 +239,23 @@ class CategoricalTest(PandasOnSparkTestCase, TestUtils):
)
pcser = pser.astype(CategoricalDtype(["c", "a", "b"]))
- kcser = psser.astype(CategoricalDtype(["c", "a", "b"]))
+ pscser = psser.astype(CategoricalDtype(["c", "a", "b"]))
- self.assert_eq(kcser.astype("category"), pcser.astype("category"))
+ self.assert_eq(pscser.astype("category"), pcser.astype("category"))
+ # CategoricalDtype is not updated if the dtype is same from pandas 1.3.
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
- # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
- pass
- elif LooseVersion(pd.__version__) >= LooseVersion("1.2"):
self.assert_eq(
- kcser.astype(CategoricalDtype(["b", "c", "a"])),
+ pscser.astype(CategoricalDtype(["b", "c", "a"])),
pcser.astype(CategoricalDtype(["b", "c", "a"])),
)
else:
self.assert_eq(
- kcser.astype(CategoricalDtype(["b", "c", "a"])),
- pser.astype(CategoricalDtype(["b", "c", "a"])),
+ pscser.astype(CategoricalDtype(["b", "c", "a"])),
+ pcser,
)
- self.assert_eq(kcser.astype(str), pcser.astype(str))
+ self.assert_eq(pscser.astype(str), pcser.astype(str))
def test_factorize(self):
pser = pd.Series(["a", "b", "c", None], dtype=CategoricalDtype(["c", "a", "d", "b"]))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org