You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/01/05 01:29:41 UTC

[3/9] git commit: Let reduceByKey take care of local combine

Let reduceByKey take care of local combine

Also refactored some heavy FP code to improve readability and reduce memory footprint.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/c0337c5b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/c0337c5b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/c0337c5b

Branch: refs/heads/master
Commit: c0337c5bbfd5126c64964a9fdefd2bef11727d87
Parents: 3bb714e
Author: Lian, Cheng <rh...@gmail.com>
Authored: Wed Dec 25 22:45:57 2013 +0800
Committer: Lian, Cheng <rh...@gmail.com>
Committed: Wed Dec 25 22:45:57 2013 +0800

----------------------------------------------------------------------
 .../spark/mllib/classification/NaiveBayes.scala | 43 ++++++++------------
 1 file changed, 16 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/c0337c5b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index edea5ed..4c96b24 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.mllib.classification
 
-import scala.collection.mutable
-
 import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
@@ -63,39 +61,30 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
    * @param data RDD of (label, array of features) pairs.
    */
   def run(C: Int, D: Int, data: RDD[LabeledPoint]) = {
-    val locallyReduced = data.mapPartitions { iterator =>
-      val localLabelCounts = mutable.Map.empty[Int, Int].withDefaultValue(0)
-      val localSummedObservations =
-        mutable.Map.empty[Int, Array[Double]].withDefaultValue(Array.fill(D)(0.0))
-
-      for (LabeledPoint(label, features) <- iterator; i = label.toInt) {
-        localLabelCounts(i) += 1
-        localSummedObservations(i) = vectorAdd(localSummedObservations(i), features)
-      }
-
-      for ((label, count) <- localLabelCounts.toIterator) yield {
-        label -> (count, localSummedObservations(label))
-      }
-    }
-
-    val reduced = locallyReduced.reduceByKey { (lhs, rhs) =>
+    val countsAndSummedFeatures = data.map { case LabeledPoint(label, features) =>
+      label.toInt ->(1, features)
+    }.reduceByKey { (lhs, rhs) =>
       (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2))
     }
 
-    val collected = reduced.mapValues { case (count, summed) =>
+    val collected = countsAndSummedFeatures.mapValues { case (count, summedFeatureVector) =>
       val labelWeight = math.log(count + lambda)
-      val logDenom = math.log(summed.sum + D * lambda)
-      val weights = summed.map(w => math.log(w + lambda) - logDenom)
+      val logDenom = math.log(summedFeatureVector.sum + D * lambda)
+      val weights = summedFeatureVector.map(w => math.log(w + lambda) - logDenom)
       (count, labelWeight, weights)
     }.collectAsMap()
 
-    val weightPerLabel = {
-      val N = collected.values.map(_._1).sum
-      val logDenom = math.log(N + C * lambda)
-      collected.mapValues(_._2 - logDenom).toArray.sortBy(_._1).map(_._2)
-    }
+    // We can simply call `data.count` to get `N`, but that triggers another RDD action, which is
+    // considerably expensive.
+    val N = collected.values.map(_._1).sum
+    val logDenom = math.log(N + C * lambda)
+    val weightPerLabel = Array.fill[Double](C)(0)
+    val weightMatrix = Array.fill[Array[Double]](C)(null)
 
-    val weightMatrix = collected.mapValues(_._3).toArray.sortBy(_._1).map(_._2)
+    for ((label, (_, labelWeight, weights)) <- collected) {
+      weightPerLabel(label) = labelWeight - logDenom
+      weightMatrix(label) = weights
+    }
 
     new NaiveBayesModel(weightPerLabel, weightMatrix)
   }