You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2017/06/12 21:27:19 UTC

spark git commit: [SPARK-21050][ML] Word2vec persistence overflow bug fix

Repository: spark
Updated Branches:
  refs/heads/master b1436c749 -> ff318c0d2


[SPARK-21050][ML] Word2vec persistence overflow bug fix

## What changes were proposed in this pull request?

The method calculateNumberOfPartitions() uses Int, not Long (unlike the MLlib version), so it is very easy to have an overflow when calculating the number of partitions for ML persistence.

This modifies the calculations to use Long.

## How was this patch tested?

New unit test.  I verified that the test fails before this patch.

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #18265 from jkbradley/word2vec-save-fix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ff318c0d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ff318c0d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ff318c0d

Branch: refs/heads/master
Commit: ff318c0d2f283c3f46491f229f82d93714da40c7
Parents: b1436c7
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Mon Jun 12 14:27:57 2017 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Jun 12 14:27:57 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Word2Vec.scala  | 38 ++++++++++++++------
 .../apache/spark/ml/feature/Word2VecSuite.scala | 10 ++++++
 2 files changed, 38 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ff318c0d/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 4ca062c..b6909b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -339,25 +340,42 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
       val wordVectors = instance.wordVectors.getVectors
       val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) }
       val dataPath = new Path(path, "data").toString
+      val bufferSizeInBytes = Utils.byteStringAsBytes(
+        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
+      val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions(
+        bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize)
       sparkSession.createDataFrame(dataSeq)
-        .repartition(calculateNumberOfPartitions)
+        .repartition(numPartitions)
         .write
         .parquet(dataPath)
     }
+  }
 
-    def calculateNumberOfPartitions(): Int = {
-      val floatSize = 4
+  private[feature]
+  object Word2VecModelWriter {
+    /**
+     * Calculate the number of partitions to use in saving the model.
+     * [SPARK-11994] - We want to partition the model in partitions smaller than
+     * spark.kryoserializer.buffer.max
+     * @param bufferSizeInBytes  Set to spark.kryoserializer.buffer.max
+     * @param numWords  Vocab size
+     * @param vectorSize  Vector length for each word
+     */
+    def calculateNumberOfPartitions(
+        bufferSizeInBytes: Long,
+        numWords: Int,
+        vectorSize: Int): Int = {
+      val floatSize = 4L  // Use Long to help avoid overflow
       val averageWordSize = 15
-      // [SPARK-11994] - We want to partition the model in partitions smaller than
-      // spark.kryoserializer.buffer.max
-      val bufferSizeInBytes = Utils.byteStringAsBytes(
-        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
       // Calculate the approximate size of the model.
       // Assuming an average word size of 15 bytes, the formula is:
       // (floatSize * vectorSize + 15) * numWords
-      val numWords = instance.wordVectors.wordIndex.size
-      val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords
-      ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
+      val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
+      val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1
+      require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " +
+        s"partitions to save this model, which is too large.  Try increasing " +
+        s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.")
+      numPartitions.toInt
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ff318c0d/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index a6a1c2b..6183606 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
+import org.apache.spark.util.Utils
 
 class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -188,6 +189,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5)
   }
 
+  test("Word2Vec read/write numPartitions calculation") {
+    val smallModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 10, vectorSize = 5)
+    assert(smallModelNumPartitions === 1)
+    val largeModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 1000000, vectorSize = 5000)
+    assert(largeModelNumPartitions > 1)
+  }
+
   test("Word2Vec read/write") {
     val t = new Word2Vec()
       .setInputCol("myInputCol")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org