You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/19 22:53:13 UTC

spark git commit: [SPARK-7652] [MLLIB] Update the implementation of naive Bayes prediction with BLAS

Repository: spark
Updated Branches:
  refs/heads/master 68fb2a46e -> c12dff9b8


[SPARK-7652] [MLLIB] Update the implementation of naive Bayes prediction with BLAS

JIRA: https://issues.apache.org/jira/browse/SPARK-7652

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #6189 from viirya/naive_bayes_blas_prediction and squashes the following commits:

ab611fd [Liang-Chi Hsieh] Remove unnecessary space.
ddc48b9 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into naive_bayes_blas_prediction
b5772b4 [Liang-Chi Hsieh] Fix binary compatibility.
2f65186 [Liang-Chi Hsieh] Remove toDense.
1b6cdfe [Liang-Chi Hsieh] Update the implementation of naive Bayes prediction with BLAS.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c12dff9b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c12dff9b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c12dff9b

Branch: refs/heads/master
Commit: c12dff9b82e4869f866a9b96ce0bf05503dd7dda
Parents: 68fb2a4
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Tue May 19 13:53:08 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue May 19 13:53:08 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/classification/NaiveBayes.scala | 41 ++++++++++++--------
 1 file changed, 24 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c12dff9b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index ac0ebec..53fb2cb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,13 +21,11 @@ import java.lang.{Iterable => JIterable}
 
 import scala.collection.JavaConverters._
 
-import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
-import breeze.numerics.{exp => brzExp, log => brzLog}
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{Logging, SparkContext, SparkException}
-import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector}
+import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
@@ -50,6 +48,9 @@ class NaiveBayesModel private[mllib] (
     val modelType: String)
   extends ClassificationModel with Serializable with Saveable {
 
+  private val piVector = new DenseVector(pi)
+  private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
+
   private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
     this(labels, pi, theta, "Multinomial")
 
@@ -60,17 +61,18 @@ class NaiveBayesModel private[mllib] (
       theta: JIterable[JIterable[Double]]) =
     this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
 
-  private val brzPi = new BDV[Double](pi)
-  private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
-
   // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
-  // This precomputes log(1.0 - exp(theta)) and its sum  which are used for the  linear algebra
+  // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
   // application of this condition (in predict function).
-  private val (brzNegTheta, brzNegThetaSum) = modelType match {
+  private val (thetaMinusNegTheta, negThetaSum) = modelType match {
     case "Multinomial" => (None, None)
     case "Bernoulli" =>
-      val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
-      (Option(negTheta), Option(brzSum(negTheta, Axis._1)))
+      val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
+      val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
+      val thetaMinusNegTheta = thetaMatrix.map { value =>
+        value - math.log(1.0 - math.exp(value))
+      }
+      (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
     case _ =>
       // This should never happen.
       throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
@@ -85,17 +87,22 @@ class NaiveBayesModel private[mllib] (
   }
 
   override def predict(testData: Vector): Double = {
-    val brzData = testData.toBreeze
     modelType match {
       case "Multinomial" =>
-        labels(brzArgmax(brzPi + brzTheta * brzData))
+        val prob = thetaMatrix.multiply(testData)
+        BLAS.axpy(1.0, piVector, prob)
+        labels(prob.argmax)
       case "Bernoulli" =>
-        if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
-          throw new SparkException(
-            s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+        testData.foreachActive { (index, value) =>
+          if (value != 0.0 && value != 1.0) {
+            throw new SparkException(
+              s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+          }
         }
-        labels(brzArgmax(brzPi +
-          (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
+        val prob = thetaMinusNegTheta.get.multiply(testData)
+        BLAS.axpy(1.0, piVector, prob)
+        BLAS.axpy(1.0, negThetaSum.get, prob)
+        labels(prob.argmax)
       case _ =>
         // This should never happen.
         throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org