You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/08/14 10:59:22 UTC

spark git commit: [SPARK-17027][ML] Avoid integer overflow in PolynomialExpansion.getPolySize

Repository: spark
Updated Branches:
  refs/heads/master cdaa562c9 -> 0ebf7c1bf


[SPARK-17027][ML] Avoid integer overflow in PolynomialExpansion.getPolySize

## What changes were proposed in this pull request?

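Replaces the custom `choose` function, whose Int arithmetic overflows for moderately large inputs (e.g. 5 features at degree 10), with org.apache.commons.math3.util.CombinatoricsUtils.binomialCoefficient, and requires that the resulting size fit in an Int.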

## How was this patch tested?

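Existing Spark unit tests, plus a new regression test in PolynomialExpansionSuite that checks output sizes at degrees 10 and 11 for 5- and 6-feature vectors.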

Author: zero323 <ze...@users.noreply.github.com>

Closes #14614 from zero323/SPARK-17027.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0ebf7c1b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0ebf7c1b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0ebf7c1b

Branch: refs/heads/master
Commit: 0ebf7c1bff736cf54ec47957d71394d5b75b47a7
Parents: cdaa562
Author: zero323 <ze...@users.noreply.github.com>
Authored: Sun Aug 14 11:59:24 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Aug 14 11:59:24 2016 +0100

----------------------------------------------------------------------
 .../spark/ml/feature/PolynomialExpansion.scala  | 10 ++++----
 .../ml/feature/PolynomialExpansionSuite.scala   | 24 ++++++++++++++++++++
 2 files changed, 30 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0ebf7c1b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 72fb35b..6e872c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.feature
 
 import scala.collection.mutable
 
+import org.apache.commons.math3.util.CombinatoricsUtils
+
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.linalg._
@@ -84,12 +86,12 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str
 @Since("1.6.0")
 object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] {
 
-  private def choose(n: Int, k: Int): Int = {
-    Range(n, n - k, -1).product / Range(k, 1, -1).product
+  private def getPolySize(numFeatures: Int, degree: Int): Int = {
+    val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree)
+    require(n <= Integer.MAX_VALUE, s"Polynomial expansion size $n exceeds Integer.MAX_VALUE")
+    n.toInt
   }
 
-  private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree)
-
   private def expandDense(
       values: Array[Double],
       lastIdx: Int,

http://git-wip-us.apache.org/repos/asf/spark/blob/0ebf7c1b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index 8e1f9dd..9ecd321 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -116,5 +116,29 @@ class PolynomialExpansionSuite
       .setDegree(3)
     testDefaultReadWrite(t)
   }
+
+  test("SPARK-17027. Integer overflow in PolynomialExpansion.getPolySize") {
+    // Expected size = C(numFeatures + degree, degree) - 1; the old choose overflowed Int here.
+    val data: Array[(Vector, Int, Int)] = Array(
+      (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367),
+      (Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367),
+      (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375)
+    )
+    val df = spark.createDataFrame(data)
+      .toDF("features", "expectedPoly10size", "expectedPoly11size")
+
+    val t = new PolynomialExpansion()
+      .setInputCol("features")
+      .setOutputCol("polyFeatures")
+
+    for (i <- Seq(10, 11)) {
+      val sizes = t.setDegree(i)
+        .transform(df)
+        .select(s"expectedPoly${i}size", "polyFeatures")
+        .collect().map { case Row(expected: Int, v: Vector) => (expected, v.size) }
+      // Check row by row so a failure reports the actual size and the degree under test.
+      sizes.foreach { case (expected, actual) => assert(actual === expected, s"degree $i") }
+    }
+  }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org