You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/08/10 00:42:27 UTC

spark git commit: [SPARK-14932][SQL] Allow DataFrame.replace() to replace values with None

Repository: spark
Updated Branches:
  refs/heads/master c06f3f5ac -> 84454d7d3


[SPARK-14932][SQL] Allow DataFrame.replace() to replace values with None

## What changes were proposed in this pull request?

Currently `df.na.replace("*", Map[String, String]("NULL" -> null))` throws an exception.
This PR enables passing null/None as a value in the replacement map of DataFrame.replace().
Note that the replacement map's keys and values must still be of the same type, although the values may mix null/None with that type.
For example, this PR enables the following operations:
`df.na.replace("*", Map[String, String]("NULL" -> null))`(scala)
`df.na.replace("*", Map[Any, Any](60 -> null, 70 -> 80))`(scala)
`df.na.replace('Alice', None)`(python)
`df.na.replace([10, 20])` (python; replaces with None by default)
One use case: replace all empty strings with null/None because they were generated incorrectly, and then drop all null/None data:
`df.na.replace("*", Map("" -> null)).na.drop()`(scala)
`df.replace(u'', None).dropna()`(python)

## How was this patch tested?

Scala unit test.
Python doctest and unit test.

Author: bravo-zhang <mz...@gmail.com>

Closes #18820 from bravo-zhang/spark-14932.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84454d7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84454d7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84454d7d

Branch: refs/heads/master
Commit: 84454d7d33363a41adf242c8a81ffca20769c55c
Parents: c06f3f5
Author: bravo-zhang <mz...@gmail.com>
Authored: Wed Aug 9 17:42:21 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Wed Aug 9 17:42:21 2017 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 35 +++++++-----
 python/pyspark/sql/tests.py                     | 15 ++++++
 .../apache/spark/sql/DataFrameNaFunctions.scala | 57 +++++++++++---------
 .../spark/sql/DataFrameNaFunctionsSuite.scala   | 43 +++++++++++++++
 4 files changed, 113 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/84454d7d/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 944739b..edc7ca6 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1362,8 +1362,8 @@ class DataFrame(object):
         """Returns a new :class:`DataFrame` replacing a value with another value.
         :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
         aliases of each other.
-        Values to_replace and value should contain either all numerics, all booleans,
-        or all strings. When replacing, the new value will be cast
+        Values to_replace and value must have the same type and can only be numerics, booleans,
+        or strings. Value can have None. When replacing, the new value will be cast
         to the type of the existing column.
         For numeric replacements all values to be replaced should have unique
         floating point representation. In case of conflicts (for example with `{42: -1, 42.0: 1}`)
@@ -1373,8 +1373,8 @@ class DataFrame(object):
             Value to be replaced.
             If the value is a dict, then `value` is ignored and `to_replace` must be a
             mapping between a value and a replacement.
-        :param value: int, long, float, string, or list.
-            The replacement value must be an int, long, float, or string. If `value` is a
+        :param value: bool, int, long, float, string, list or None.
+            The replacement value must be a bool, int, long, float, string or None. If `value` is a
             list, `value` should be of the same length and type as `to_replace`.
             If `value` is a scalar and `to_replace` is a sequence, then `value` is
             used as a replacement for each item in `to_replace`.
@@ -1393,6 +1393,16 @@ class DataFrame(object):
         |null|  null| null|
         +----+------+-----+
 
+        >>> df4.na.replace('Alice', None).show()
+        +----+------+----+
+        | age|height|name|
+        +----+------+----+
+        |  10|    80|null|
+        |   5|  null| Bob|
+        |null|  null| Tom|
+        |null|  null|null|
+        +----+------+----+
+
         >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
         +----+------+----+
         | age|height|name|
@@ -1425,12 +1435,13 @@ class DataFrame(object):
         valid_types = (bool, float, int, long, basestring, list, tuple)
         if not isinstance(to_replace, valid_types + (dict, )):
             raise ValueError(
-                "to_replace should be a float, int, long, string, list, tuple, or dict. "
+                "to_replace should be a bool, float, int, long, string, list, tuple, or dict. "
                 "Got {0}".format(type(to_replace)))
 
-        if not isinstance(value, valid_types) and not isinstance(to_replace, dict):
+        if not isinstance(value, valid_types) and value is not None \
+                and not isinstance(to_replace, dict):
             raise ValueError("If to_replace is not a dict, value should be "
-                             "a float, int, long, string, list, or tuple. "
+                             "a bool, float, int, long, string, list, tuple or None. "
                              "Got {0}".format(type(value)))
 
         if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
@@ -1446,21 +1457,21 @@ class DataFrame(object):
         if isinstance(to_replace, (float, int, long, basestring)):
             to_replace = [to_replace]
 
-        if isinstance(value, (float, int, long, basestring)):
-            value = [value for _ in range(len(to_replace))]
-
         if isinstance(to_replace, dict):
             rep_dict = to_replace
             if value is not None:
                 warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
         else:
+            if isinstance(value, (float, int, long, basestring)) or value is None:
+                value = [value for _ in range(len(to_replace))]
             rep_dict = dict(zip(to_replace, value))
 
         if isinstance(subset, basestring):
             subset = [subset]
 
-        # Verify we were not passed in mixed type generics."
-        if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values())
+        # Verify we were not passed in mixed type generics.
+        if not any(all_of_type(rep_dict.keys())
+                   and all_of_type(x for x in rep_dict.values() if x is not None)
                    for all_of_type in [all_of_bool, all_of_str, all_of_numeric]):
             raise ValueError("Mixed type replacements are not supported")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/84454d7d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index cfd9c55..cf2c473 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1964,6 +1964,21 @@ class SQLTests(ReusedPySparkTestCase):
                .replace(False, True).first())
         self.assertTupleEqual(row, (True, True))
 
+        # replace list while value is not given (default to None)
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()
+        self.assertTupleEqual(row, (None, 10, 80.0))
+
+        # replace string with None and then drop None rows
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
+        self.assertEqual(row.count(), 0)
+
+        # replace with number and None
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first()
+        self.assertTupleEqual(row, (u'Alice', 20, None))
+
         # should fail if subset is not list, tuple or None
         with self.assertRaises(ValueError):
             self.spark.createDataFrame(

http://git-wip-us.apache.org/repos/asf/spark/blob/84454d7d/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 871fff7..e068df3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -260,9 +260,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
 
   /**
    * Replaces values matching keys in `replacement` map with the corresponding values.
-   * Key and value of `replacement` map must have the same type, and
-   * can only be doubles, strings or booleans.
-   * If `col` is "*", then the replacement is applied on all string columns or numeric columns.
    *
    * {{{
    *   import com.google.common.collect.ImmutableMap;
@@ -277,8 +274,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *   df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
    * }}}
    *
-   * @param col name of the column to apply the value replacement
-   * @param replacement value replacement map, as explained above
+   * @param col name of the column to apply the value replacement. If `col` is "*",
+   *            replacement is applied on all string, numeric or boolean columns.
+   * @param replacement value replacement map. Key and value of `replacement` map must have
+   *                    the same type, and can only be doubles, strings or booleans.
+   *                    The map value can have nulls.
    *
    * @since 1.3.1
    */
@@ -288,8 +288,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
 
   /**
    * Replaces values matching keys in `replacement` map with the corresponding values.
-   * Key and value of `replacement` map must have the same type, and
-   * can only be doubles, strings or booleans.
    *
    * {{{
    *   import com.google.common.collect.ImmutableMap;
@@ -301,8 +299,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *   df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
    * }}}
    *
-   * @param cols list of columns to apply the value replacement
-   * @param replacement value replacement map, as explained above
+   * @param cols list of columns to apply the value replacement. If `col` is "*",
+   *             replacement is applied on all string, numeric or boolean columns.
+   * @param replacement value replacement map. Key and value of `replacement` map must have
+   *                    the same type, and can only be doubles, strings or booleans.
+   *                    The map value can have nulls.
    *
    * @since 1.3.1
    */
@@ -312,10 +313,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
 
   /**
    * (Scala-specific) Replaces values matching keys in `replacement` map.
-   * Key and value of `replacement` map must have the same type, and
-   * can only be doubles, strings or booleans.
-   * If `col` is "*",
-   * then the replacement is applied on all string columns , numeric columns or boolean columns.
    *
    * {{{
    *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
@@ -328,8 +325,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *   df.na.replace("*", Map("UNKNOWN" -> "unnamed"));
    * }}}
    *
-   * @param col name of the column to apply the value replacement
-   * @param replacement value replacement map, as explained above
+   * @param col name of the column to apply the value replacement. If `col` is "*",
+   *            replacement is applied on all string, numeric or boolean columns.
+   * @param replacement value replacement map. Key and value of `replacement` map must have
+   *                    the same type, and can only be doubles, strings or booleans.
+   *                    The map value can have nulls.
    *
    * @since 1.3.1
    */
@@ -343,8 +343,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
 
   /**
    * (Scala-specific) Replaces values matching keys in `replacement` map.
-   * Key and value of `replacement` map must have the same type, and
-   * can only be doubles , strings or booleans.
    *
    * {{{
    *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
@@ -354,8 +352,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *   df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"));
    * }}}
    *
-   * @param cols list of columns to apply the value replacement
-   * @param replacement value replacement map, as explained above
+   * @param cols list of columns to apply the value replacement. If `col` is "*",
+   *             replacement is applied on all string, numeric or boolean columns.
+   * @param replacement value replacement map. Key and value of `replacement` map must have
+   *                    the same type, and can only be doubles, strings or booleans.
+   *                    The map value can have nulls.
    *
    * @since 1.3.1
    */
@@ -366,14 +367,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       return df
     }
 
-    // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean]
-    val replacementMap: Map[_, _] = replacement.head._2 match {
-      case v: String => replacement
-      case v: Boolean => replacement
-      case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) }
+    // Convert the NumericType in replacement map to DoubleType,
+    // while leaving StringType, BooleanType and null untouched.
+    val replacementMap: Map[_, _] = replacement.map {
+      case (k, v: String) => (k, v)
+      case (k, v: Boolean) => (k, v)
+      case (k: String, null) => (k, null)
+      case (k: Boolean, null) => (k, null)
+      case (k, null) => (convertToDouble(k), null)
+      case (k, v) => (convertToDouble(k), convertToDouble(v))
     }
 
-    // targetColumnType is either DoubleType or StringType or BooleanType
+    // targetColumnType is either DoubleType, StringType or BooleanType,
+    // depending on the type of first key in replacement map.
+    // Only fields of targetColumnType will perform replacement.
     val targetColumnType = replacement.head._1 match {
       case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType
       case _: jl.Boolean => BooleanType

http://git-wip-us.apache.org/repos/asf/spark/blob/84454d7d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 47c9ba5..e6983b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -262,4 +262,47 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(out1(4) === Row("Amy", null, null))
     assert(out1(5) === Row(null, null, null))
   }
+
+  test("replace with null") {
+    val input = Seq[(String, java.lang.Double, java.lang.Boolean)](
+      ("Bob", 176.5, true),
+      ("Alice", 164.3, false),
+      ("David", null, true)
+    ).toDF("name", "height", "married")
+
+    // Replace String with String and null
+    checkAnswer(
+      input.na.replace("name", Map(
+        "Bob" -> "Bravo",
+        "Alice" -> null
+      )),
+      Row("Bravo", 176.5, true) ::
+        Row(null, 164.3, false) ::
+        Row("David", null, true) :: Nil)
+
+    // Replace Double with null
+    checkAnswer(
+      input.na.replace("height", Map[Any, Any](
+        164.3 -> null
+      )),
+      Row("Bob", 176.5, true) ::
+        Row("Alice", null, false) ::
+        Row("David", null, true) :: Nil)
+
+    // Replace Boolean with null
+    checkAnswer(
+      input.na.replace("*", Map[Any, Any](
+        false -> null
+      )),
+      Row("Bob", 176.5, true) ::
+        Row("Alice", 164.3, null) ::
+        Row("David", null, true) :: Nil)
+
+    // Replace String with null and then drop rows containing null
+    checkAnswer(
+      input.na.replace("name", Map(
+        "Bob" -> null
+      )).na.drop("name" :: Nil).select("name"),
+      Row("Alice") :: Row("David") :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org