You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/01/05 01:29:41 UTC
[3/9] git commit: Let reduceByKey take care of local combine
Let reduceByKey take care of local combine
Also refactored some heavy FP code to improve readability and reduce memory footprint.
Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/c0337c5b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/c0337c5b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/c0337c5b
Branch: refs/heads/master
Commit: c0337c5bbfd5126c64964a9fdefd2bef11727d87
Parents: 3bb714e
Author: Lian, Cheng <rh...@gmail.com>
Authored: Wed Dec 25 22:45:57 2013 +0800
Committer: Lian, Cheng <rh...@gmail.com>
Committed: Wed Dec 25 22:45:57 2013 +0800
----------------------------------------------------------------------
.../spark/mllib/classification/NaiveBayes.scala | 43 ++++++++------------
1 file changed, 16 insertions(+), 27 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/c0337c5b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index edea5ed..4c96b24 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,8 +17,6 @@
package org.apache.spark.mllib.classification
-import scala.collection.mutable
-
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -63,39 +61,30 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
* @param data RDD of (label, array of features) pairs.
*/
def run(C: Int, D: Int, data: RDD[LabeledPoint]) = {
- val locallyReduced = data.mapPartitions { iterator =>
- val localLabelCounts = mutable.Map.empty[Int, Int].withDefaultValue(0)
- val localSummedObservations =
- mutable.Map.empty[Int, Array[Double]].withDefaultValue(Array.fill(D)(0.0))
-
- for (LabeledPoint(label, features) <- iterator; i = label.toInt) {
- localLabelCounts(i) += 1
- localSummedObservations(i) = vectorAdd(localSummedObservations(i), features)
- }
-
- for ((label, count) <- localLabelCounts.toIterator) yield {
- label -> (count, localSummedObservations(label))
- }
- }
-
- val reduced = locallyReduced.reduceByKey { (lhs, rhs) =>
+ val countsAndSummedFeatures = data.map { case LabeledPoint(label, features) =>
+ label.toInt ->(1, features)
+ }.reduceByKey { (lhs, rhs) =>
(lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2))
}
- val collected = reduced.mapValues { case (count, summed) =>
+ val collected = countsAndSummedFeatures.mapValues { case (count, summedFeatureVector) =>
val labelWeight = math.log(count + lambda)
- val logDenom = math.log(summed.sum + D * lambda)
- val weights = summed.map(w => math.log(w + lambda) - logDenom)
+ val logDenom = math.log(summedFeatureVector.sum + D * lambda)
+ val weights = summedFeatureVector.map(w => math.log(w + lambda) - logDenom)
(count, labelWeight, weights)
}.collectAsMap()
- val weightPerLabel = {
- val N = collected.values.map(_._1).sum
- val logDenom = math.log(N + C * lambda)
- collected.mapValues(_._2 - logDenom).toArray.sortBy(_._1).map(_._2)
- }
+ // We could simply call `data.count` to get `N`, but that would trigger another RDD action,
+ // which is considerably more expensive than summing the per-label counts collected above.
+ val N = collected.values.map(_._1).sum
+ val logDenom = math.log(N + C * lambda)
+ val weightPerLabel = Array.fill[Double](C)(0)
+ val weightMatrix = Array.fill[Array[Double]](C)(null)
- val weightMatrix = collected.mapValues(_._3).toArray.sortBy(_._1).map(_._2)
+ for ((label, (_, labelWeight, weights)) <- collected) {
+ weightPerLabel(label) = labelWeight - logDenom
+ weightMatrix(label) = weights
+ }
new NaiveBayesModel(weightPerLabel, weightMatrix)
}