You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2020/05/14 14:28:25 UTC

[spark] branch branch-2.4 updated: [SPARK-31676][ML] QuantileDiscretizer raise error parameter splits given invalid value (splits array includes -0.0 and 0.0)

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 1ea5844  [SPARK-31676][ML] QuantileDiscretizer raise error parameter splits given invalid value (splits array includes -0.0 and 0.0)
1ea5844 is described below

commit 1ea584443e9372a6a0b3c8449f5bf7e9e1369b0d
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Thu May 14 09:24:40 2020 -0500

    [SPARK-31676][ML] QuantileDiscretizer raise error parameter splits given invalid value (splits array includes -0.0 and 0.0)
    
    In QuantileDiscretizer.getDistinctSplits, before invoking distinct, normalize all -0.0 and 0.0 to be 0.0
    ```
        for (i <- 0 until splits.length) {
          if (splits(i) == -0.0) {
            splits(i) = 0.0
          }
        }
    ```
    Fix bug.
    
    No
    
    Unit test.
    
    ~~~scala
    import scala.util.Random
    val rng = new Random(3)
    
    val a1 = Array.tabulate(200)(_=>rng.nextDouble * 2.0 - 1.0) ++ Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
    
    import spark.implicits._
    val df1 = sc.parallelize(a1, 2).toDF("id")
    
    import org.apache.spark.ml.feature.QuantileDiscretizer
    val qd = new QuantileDiscretizer().setInputCol("id").setOutputCol("out").setNumBuckets(200).setRelativeError(0.0)
    
    val model = qd.fit(df1) // will raise error in spark master.
    ~~~
    
    In Scala, `0.0 == -0.0` is true but `0.0.hashCode == -0.0.hashCode()` is false. This breaks the contract between equals() and hashCode(): if two objects are equal, then they must have the same hash code.
    
    And array.distinct relies on elem.hashCode, so it leads to this error.
    
    Test code on distinct
    ```
    import scala.util.Random
    val rng = new Random(3)
    
    val a1 = Array.tabulate(200)(_=>rng.nextDouble * 2.0 - 1.0) ++ Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
    a1.distinct.sorted.foreach(x => print(x.toString + "\n"))
    ```
    
    Then you will see output like:
    ```
    ...
    -0.009292684662246975
    -0.0033280686465135823
    -0.0
    0.0
    0.0022219556032221366
    0.02217419561977274
    ...
    ```
    
    Closes #28498 from WeichenXu123/SPARK-31676.
    
    Authored-by: Weichen Xu <we...@databricks.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
    (cherry picked from commit b2300fca1e1a22d74c6eeda37942920a6c6299ff)
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../apache/spark/ml/feature/QuantileDiscretizer.scala  | 12 ++++++++++++
 .../spark/ml/feature/QuantileDiscretizerSuite.scala    | 18 ++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 56e2c54..f3ec358 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -243,6 +243,18 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
   private def getDistinctSplits(splits: Array[Double]): Array[Double] = {
     splits(0) = Double.NegativeInfinity
     splits(splits.length - 1) = Double.PositiveInfinity
+
+    // 0.0 and -0.0 are distinct values, and array.distinct will preserve both of them,
+    // but 0.0 > -0.0 is false, which will break the parameter validation check.
+    // Also, in Scala <= 2.12, there is a bug which causes array.distinct to produce
+    // non-deterministic results when the array contains both 0.0 and -0.0.
+    // So here we first normalize all 0.0 and -0.0 to be 0.0.
+    // See https://github.com/scala/bug/issues/11995
+    for (i <- 0 until splits.length) {
+      if (splits(i) == -0.0) {
+        splits(i) = 0.0
+      }
+    }
     val distinctSplits = splits.distinct
     if (splits.length != distinctSplits.length) {
       log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index b009038..9c37416 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -443,4 +443,22 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
       discretizer.fit(df)
     }
   }
+
+  test("[SPARK-31676] QuantileDiscretizer raise error parameter splits given invalid value") {
+    import scala.util.Random
+    val rng = new Random(3)
+
+    val a1 = Array.tabulate(200)(_ => rng.nextDouble * 2.0 - 1.0) ++
+      Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
+
+    val df1 = sc.parallelize(a1, 2).toDF("id")
+
+    val qd = new QuantileDiscretizer()
+      .setInputCol("id")
+      .setOutputCol("out")
+      .setNumBuckets(200)
+      .setRelativeError(0.0)
+
+    qd.fit(df1) // assert no exception raised here.
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org