Posted to commits@spark.apache.org by ue...@apache.org on 2021/08/16 18:07:48 UTC

[spark] branch branch-3.2 updated: [SPARK-36469][PYTHON] Implement Index.map

This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new cb14a32  [SPARK-36469][PYTHON] Implement Index.map
cb14a32 is described below

commit cb14a3200524ab757f7e68b35e08c43467c29d8b
Author: Xinrong Meng <xi...@databricks.com>
AuthorDate: Mon Aug 16 11:06:10 2021 -0700

    [SPARK-36469][PYTHON] Implement Index.map
    
    ### What changes were proposed in this pull request?
    Implement `Index.map`.
    
    The PR is based on https://github.com/databricks/koalas/pull/2136. Thanks to awdavidson for the prototype.
    
    `map` of CategoricalIndex and DatetimeIndex will be implemented in separate PRs.
    
    ### Why are the changes needed?
    Mapping values using input correspondence (a dict, Series, or function) is supported in pandas as [Index.map](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Index.map.html).
    We shall also support that.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `Index.map` is available now.
    
    ```py
    >>> psidx = ps.Index([1, 2, 3])
    
    >>> psidx.map({1: "one", 2: "two", 3: "three"})
    Index(['one', 'two', 'three'], dtype='object')
    
    >>> psidx.map(lambda id: "{id} + 1".format(id=id))
    Index(['1 + 1', '2 + 1', '3 + 1'], dtype='object')
    
    >>> pser = pd.Series(["one", "two", "three"], index=[1, 2, 3])
    >>> psidx.map(pser)
    Index(['one', 'two', 'three'], dtype='object')
    ```
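
    Unmapped keys and the `na_action` parameter follow pandas semantics. Here is a short sketch of the behavior the new unit tests exercise (calls only; no additional output is claimed):

    ```py
    import pyspark.pandas as ps

    psidx = ps.Index([1, 2, 3])

    # Keys absent from the dict become missing values, as in pandas.
    psidx.map({1: "one", 2: "two"})

    # With na_action="ignore", NA inputs are propagated without being
    # passed to the mapping correspondence.
    psidx.map({1: "one", 2: "two"}, na_action="ignore")

    # A dict whose values mix types raises TypeError:
    # psidx.map({1: 1, 2: 2.0, 3: "three"})

    # CategoricalIndex and DatetimeIndex keep raising
    # PandasNotImplementedError until the follow-up PRs land.
    ```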
    
    ### How was this patch tested?
    Unit tests.
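    One way to run them locally (assuming the standard Spark dev scripts): `python/run-tests --testnames pyspark.pandas.tests.indexes.test_base`.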
    
    Closes #33694 from xinrong-databricks/index_map.
    
    Authored-by: Xinrong Meng <xi...@databricks.com>
    Signed-off-by: Takuya UESHIN <ue...@databricks.com>
    (cherry picked from commit 4dcd74602571d36a3b9129f0886e1cfc33d7fdc8)
    Signed-off-by: Takuya UESHIN <ue...@databricks.com>
---
 .../source/reference/pyspark.pandas/indexing.rst   |  1 +
 python/pyspark/pandas/indexes/base.py              | 46 +++++++++++++-
 python/pyspark/pandas/indexes/category.py          |  7 ++
 python/pyspark/pandas/indexes/datetimes.py         |  9 ++-
 python/pyspark/pandas/indexes/multi.py             |  7 ++
 python/pyspark/pandas/missing/indexes.py           |  2 +-
 python/pyspark/pandas/tests/indexes/test_base.py   | 74 ++++++++++++++++++++++
 7 files changed, 143 insertions(+), 3 deletions(-)

diff --git a/python/docs/source/reference/pyspark.pandas/indexing.rst b/python/docs/source/reference/pyspark.pandas/indexing.rst
index 677d80f..9d53f00 100644
--- a/python/docs/source/reference/pyspark.pandas/indexing.rst
+++ b/python/docs/source/reference/pyspark.pandas/indexing.rst
@@ -64,6 +64,7 @@ Modifying and computations
    Index.drop_duplicates
    Index.min
    Index.max
+   Index.map
    Index.rename
    Index.repeat
    Index.take
diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py
index 6c842bc..a43a5d1 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -16,7 +16,7 @@
 #
 
 from functools import partial
-from typing import Any, Iterator, List, Optional, Tuple, Union, cast, no_type_check
+from typing import Any, Callable, Iterator, List, Optional, Tuple, Union, cast, no_type_check
 import warnings
 
 import pandas as pd
@@ -521,6 +521,50 @@ class Index(IndexOpsMixin):
             result = result.copy()
         return result
 
+    def map(
+        self, mapper: Union[dict, Callable[[Any], Any], pd.Series], na_action: Optional[str] = None
+    ) -> "Index":
+        """
+        Map values using input correspondence (a dict, Series, or function).
+
+        Parameters
+        ----------
+        mapper : function, dict, or pd.Series
+            Mapping correspondence.
+        na_action : {None, 'ignore'}
+            If 'ignore', propagate NA values, without passing them to the mapping correspondence.
+
+        Returns
+        -------
+        applied : Index, inferred
+            The output of the mapping function applied to the index.
+
+        Examples
+        --------
+        >>> psidx = ps.Index([1, 2, 3])
+
+        >>> psidx.map({1: "one", 2: "two", 3: "three"})
+        Index(['one', 'two', 'three'], dtype='object')
+
+        >>> psidx.map(lambda id: "{id} + 1".format(id=id))
+        Index(['1 + 1', '2 + 1', '3 + 1'], dtype='object')
+
+        >>> pser = pd.Series(["one", "two", "three"], index=[1, 2, 3])
+        >>> psidx.map(pser)
+        Index(['one', 'two', 'three'], dtype='object')
+        """
+        if isinstance(mapper, dict):
+        if len(set(type(v) for v in mapper.values())) > 1:
+                raise TypeError(
+                    "If the mapper is a dictionary, its values must be of the same type"
+                )
+
+        return Index(
+            self.to_series().pandas_on_spark.transform_batch(
+                lambda pser: pser.map(mapper, na_action)
+            )
+        ).rename(self.name)
+
     @property
     def values(self) -> np.ndarray:
         """
diff --git a/python/pyspark/pandas/indexes/category.py b/python/pyspark/pandas/indexes/category.py
index e2dbd33..193c126 100644
--- a/python/pyspark/pandas/indexes/category.py
+++ b/python/pyspark/pandas/indexes/category.py
@@ -642,6 +642,13 @@ class CategoricalIndex(Index):
                 return partial(property_or_func, self)
         raise AttributeError("'CategoricalIndex' object has no attribute '{}'".format(item))
 
+    def map(
+        self,
+        mapper: Union[dict, Callable[[Any], Any], pd.Series] = None,
+        na_action: Optional[str] = None,
+    ) -> "Index":
+        return MissingPandasLikeCategoricalIndex.map(self, mapper, na_action)
+
 
 def _test() -> None:
     import os
diff --git a/python/pyspark/pandas/indexes/datetimes.py b/python/pyspark/pandas/indexes/datetimes.py
index 6998adf..691d8f9 100644
--- a/python/pyspark/pandas/indexes/datetimes.py
+++ b/python/pyspark/pandas/indexes/datetimes.py
@@ -16,7 +16,7 @@
 #
 import datetime
 from functools import partial
-from typing import Any, Optional, Union, cast, no_type_check
+from typing import Any, Callable, Optional, Union, cast, no_type_check
 
 import pandas as pd
 from pandas.api.types import is_hashable
@@ -741,6 +741,13 @@ class DatetimeIndex(Index):
             psdf = psdf.pandas_on_spark.apply_batch(pandas_at_time)
         return ps.Index(first_series(psdf).rename(self.name))
 
+    def map(
+        self,
+        mapper: Union[dict, Callable[[Any], Any], pd.Series] = None,
+        na_action: Optional[str] = None,
+    ) -> "Index":
+        return MissingPandasLikeDatetimeIndex.map(self, mapper, na_action)
+
 
 def disallow_nanoseconds(freq: Union[str, DateOffset]) -> None:
     if freq in ["N", "ns"]:
diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index 4b5ec04..fb02080 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -1165,6 +1165,13 @@ class MultiIndex(Index):
     def __iter__(self) -> Iterator:
         return MissingPandasLikeMultiIndex.__iter__(self)
 
+    def map(
+        self,
+        mapper: Union[dict, Callable[[Any], Any], pd.Series] = None,
+        na_action: Optional[str] = None,
+    ) -> "Index":
+        return MissingPandasLikeMultiIndex.map(self, mapper, na_action)
+
 
 def _test() -> None:
     import os
diff --git a/python/pyspark/pandas/missing/indexes.py b/python/pyspark/pandas/missing/indexes.py
index 938aea2..90e0c3e 100644
--- a/python/pyspark/pandas/missing/indexes.py
+++ b/python/pyspark/pandas/missing/indexes.py
@@ -58,7 +58,6 @@ class MissingPandasLikeIndex(object):
     is_ = _unsupported_function("is_")
     is_lexsorted_for_tuple = _unsupported_function("is_lexsorted_for_tuple")
     join = _unsupported_function("join")
-    map = _unsupported_function("map")
     putmask = _unsupported_function("putmask")
     ravel = _unsupported_function("ravel")
     reindex = _unsupported_function("reindex")
@@ -118,6 +117,7 @@ class MissingPandasLikeDatetimeIndex(MissingPandasLikeIndex):
     to_pydatetime = _unsupported_function("to_pydatetime", cls="DatetimeIndex")
     mean = _unsupported_function("mean", cls="DatetimeIndex")
     std = _unsupported_function("std", cls="DatetimeIndex")
+    map = _unsupported_function("map", cls="DatetimeIndex")
 
 
 class MissingPandasLikeCategoricalIndex(MissingPandasLikeIndex):
diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py
index 65831d1..fb1e4b2 100644
--- a/python/pyspark/pandas/tests/indexes/test_base.py
+++ b/python/pyspark/pandas/tests/indexes/test_base.py
@@ -2319,6 +2319,80 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
 
         self.assertRaises(PandasNotImplementedError, lambda: psmidx.factorize())
 
+    def test_map(self):
+        pidx = pd.Index([1, 2, 3])
+        psidx = ps.from_pandas(pidx)
+
+        # Apply dict
+        self.assert_eq(
+            pidx.map({1: "one", 2: "two", 3: "three"}),
+            psidx.map({1: "one", 2: "two", 3: "three"}),
+        )
+        self.assert_eq(
+            pidx.map({1: "one", 2: "two"}),
+            psidx.map({1: "one", 2: "two"}),
+        )
+        self.assert_eq(
+            pidx.map({1: "one", 2: "two"}, na_action="ignore"),
+            psidx.map({1: "one", 2: "two"}, na_action="ignore"),
+        )
+        self.assert_eq(
+            pidx.map({1: 10, 2: 20}),
+            psidx.map({1: 10, 2: 20}),
+        )
+        self.assert_eq(
+            (pidx + 1).map({1: 10, 2: 20}),
+            (psidx + 1).map({1: 10, 2: 20}),
+        )
+
+        # Apply lambda
+        self.assert_eq(
+            pidx.map(lambda id: id + 1),
+            psidx.map(lambda id: id + 1),
+        )
+        self.assert_eq(
+            pidx.map(lambda id: id + 1.1),
+            psidx.map(lambda id: id + 1.1),
+        )
+        self.assert_eq(
+            pidx.map(lambda id: "{id} + 1".format(id=id)),
+            psidx.map(lambda id: "{id} + 1".format(id=id)),
+        )
+        self.assert_eq(
+            (pidx + 1).map(lambda id: "{id} + 1".format(id=id)),
+            (psidx + 1).map(lambda id: "{id} + 1".format(id=id)),
+        )
+
+        # Apply series
+        pser = pd.Series(["one", "two", "three"], index=[1, 2, 3])
+        self.assert_eq(
+            pidx.map(pser),
+            psidx.map(pser),
+        )
+        pser = pd.Series(["one", "two", "three"])
+        self.assert_eq(
+            pidx.map(pser),
+            psidx.map(pser),
+        )
+        self.assert_eq(
+            pidx.map(pser, na_action="ignore"),
+            psidx.map(pser, na_action="ignore"),
+        )
+        pser = pd.Series([1, 2, 3])
+        self.assert_eq(
+            pidx.map(pser),
+            psidx.map(pser),
+        )
+        self.assert_eq(
+            (pidx + 1).map(pser),
+            (psidx + 1).map(pser),
+        )
+
+        self.assertRaises(
+            TypeError,
+            lambda: psidx.map({1: 1, 2: 2.0, 3: "three"}),
+        )
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.indexes.test_base import *  # noqa: F401
