You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/07/31 22:11:46 UTC

spark git commit: [SPARK-9308] [ML] ml.NaiveBayesModel support predicting class probabilities

Repository: spark
Updated Branches:
  refs/heads/master 060c79aab -> fbef566a1


[SPARK-9308] [ML] ml.NaiveBayesModel support predicting class probabilities

Make NaiveBayesModel support predicting class probabilities by inheriting from ProbabilisticClassificationModel.

Author: Yanbo Liang <yb...@gmail.com>

Closes #7672 from yanboliang/spark-9308 and squashes the following commits:

25e224c [Yanbo Liang] raw2probabilityInPlace should operate in-place
3ee56d6 [Yanbo Liang] change predictRaw and raw2probabilityInPlace
c07e7a2 [Yanbo Liang] ml.NaiveBayesModel support predicting class probabilities


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fbef566a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fbef566a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fbef566a

Branch: refs/heads/master
Commit: fbef566a107b47e5fddde0ea65b8587d5039062d
Parents: 060c79a
Author: Yanbo Liang <yb...@gmail.com>
Authored: Fri Jul 31 13:11:42 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Jul 31 13:11:42 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/NaiveBayes.scala    | 65 +++++++++++++++-----
 .../ml/classification/NaiveBayesSuite.scala     | 54 +++++++++++++++-
 2 files changed, 101 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fbef566a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 5be35fe..b46b676 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
  * The input feature values must be nonnegative.
  */
 class NaiveBayes(override val uid: String)
-  extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
+  extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
   with NaiveBayesParams {
 
   def this() = this(Identifiable.randomUID("nb"))
@@ -106,7 +106,7 @@ class NaiveBayesModel private[ml] (
     override val uid: String,
     val pi: Vector,
     val theta: Matrix)
-  extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
 
   import OldNaiveBayes.{Bernoulli, Multinomial}
 
@@ -129,29 +129,62 @@ class NaiveBayesModel private[ml] (
       throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
   }
 
-  override protected def predict(features: Vector): Double = {
+  override val numClasses: Int = pi.size
+
+  private def multinomialCalculation(features: Vector) = {
+    val prob = theta.multiply(features)
+    BLAS.axpy(1.0, pi, prob)
+    prob
+  }
+
+  private def bernoulliCalculation(features: Vector) = {
+    features.foreachActive((_, value) =>
+      if (value != 0.0 && value != 1.0) {
+        throw new SparkException(
+          s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.")
+      }
+    )
+    val prob = thetaMinusNegTheta.get.multiply(features)
+    BLAS.axpy(1.0, pi, prob)
+    BLAS.axpy(1.0, negThetaSum.get, prob)
+    prob
+  }
+
+  override protected def predictRaw(features: Vector): Vector = {
     $(modelType) match {
       case Multinomial =>
-        val prob = theta.multiply(features)
-        BLAS.axpy(1.0, pi, prob)
-        prob.argmax
+        multinomialCalculation(features)
       case Bernoulli =>
-        features.foreachActive{ (index, value) =>
-          if (value != 0.0 && value != 1.0) {
-            throw new SparkException(
-              s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
-          }
-        }
-        val prob = thetaMinusNegTheta.get.multiply(features)
-        BLAS.axpy(1.0, pi, prob)
-        BLAS.axpy(1.0, negThetaSum.get, prob)
-        prob.argmax
+        bernoulliCalculation(features)
       case _ =>
         // This should never happen.
         throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
     }
   }
 
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        var i = 0
+        val size = dv.size
+        val maxLog = dv.values.max
+        while (i < size) {
+          dv.values(i) = math.exp(dv.values(i) - maxLog)
+          i += 1
+        }
+        val probSum = dv.values.sum
+        i = 0
+        while (i < size) {
+          dv.values(i) = dv.values(i) / probSum
+          i += 1
+        }
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in NaiveBayesModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
+  }
+
   override def copy(extra: ParamMap): NaiveBayesModel = {
     copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/fbef566a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 264bde3..aea3d9b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.ml.classification
 
+import breeze.linalg.{Vector => BV}
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.classification.NaiveBayes
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
@@ -28,6 +31,8 @@ import org.apache.spark.sql.Row
 
 class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
 
+  import NaiveBayes.{Multinomial, Bernoulli}
+
   def validatePrediction(predictionAndLabels: DataFrame): Unit = {
     val numOfErrorPredictions = predictionAndLabels.collect().count {
       case Row(prediction: Double, label: Double) =>
@@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
   }
 
+  def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+    val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
+    val classProbs = logClassProbs.toArray.map(math.exp)
+    val classProbsSum = classProbs.sum
+    Vectors.dense(classProbs.map(_ / classProbsSum))
+  }
+
+  def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+    val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v)))
+    val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v))
+    val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
+    val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze
+    val classProbs = logClassProbs.toArray.map(math.exp)
+    val classProbsSum = classProbs.sum
+    Vectors.dense(classProbs.map(_ / classProbsSum))
+  }
+
+  def validateProbabilities(
+      featureAndProbabilities: DataFrame,
+      model: NaiveBayesModel,
+      modelType: String): Unit = {
+    featureAndProbabilities.collect().foreach {
+      case Row(features: Vector, probability: Vector) => {
+        assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
+        val expected = modelType match {
+          case Multinomial =>
+            expectedMultinomialProbabilities(model, features)
+          case Bernoulli =>
+            expectedBernoulliProbabilities(model, features)
+          case _ =>
+            throw new UnknownError(s"Invalid modelType: $modelType.")
+        }
+        assert(probability ~== expected relTol 1.0e-10)
+      }
+    }
+  }
+
   test("params") {
     ParamsSuite.checkParams(new NaiveBayes)
     val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
@@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
       piArray, thetaArray, nPoints, 17, "multinomial"))
-    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
 
+    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
     validatePrediction(predictionAndLabels)
+
+    val featureAndProbabilities = model.transform(validationDataset)
+      .select("features", "probability")
+    validateProbabilities(featureAndProbabilities, model, "multinomial")
   }
 
   test("Naive Bayes Bernoulli") {
@@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
       piArray, thetaArray, nPoints, 20, "bernoulli"))
-    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
 
+    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
     validatePrediction(predictionAndLabels)
+
+    val featureAndProbabilities = model.transform(validationDataset)
+      .select("features", "probability")
+    validateProbabilities(featureAndProbabilities, model, "bernoulli")
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org