Posted to commits@spark.apache.org by jk...@apache.org on 2015/05/13 23:13:25 UTC

spark git commit: [SPARK-7545] [MLLIB] Added check in Bernoulli Naive Bayes to make sure that both training and predict features have values of 0 or 1

Repository: spark
Updated Branches:
  refs/heads/master 5db18ba6e -> 61e05fc58


[SPARK-7545] [MLLIB] Added check in Bernoulli Naive Bayes to make sure that both training and predict features have values of 0 or 1

Author: leahmcguire <lm...@salesforce.com>

Closes #6073 from leahmcguire/binaryCheckNB and squashes the following commits:

b8442c2 [leahmcguire] changed to if else for value checks
911bf83 [leahmcguire] undid reformat
4eedf1e [leahmcguire] moved bernoulli check
9ee9e84 [leahmcguire] fixed style error
3f3b32c [leahmcguire] fixed zero one check so only called in combiner
831fd27 [leahmcguire] got test working
f44bb3c [leahmcguire] removed changes from CV branch
67253f0 [leahmcguire] added check to bernoulli to ensure feature values are zero or one
f191c71 [leahmcguire] fixed name
58d060b [leahmcguire] changed param name and test according to comments
04f0d3c [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access

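For context, a minimal training-time sketch (not part of the commit; it assumes an existing SparkContext named `sc` and hypothetical data mirroring the new test case below). With this change, training a "Bernoulli" model fails fast with a SparkException if any training feature value is not 0 or 1:

    import org.apache.spark.SparkException
    import org.apache.spark.mllib.classification.NaiveBayes
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Hypothetical data: the 2.0 feature value is neither 0 nor 1.
    val badTrain = Seq(
      LabeledPoint(1.0, Vectors.dense(1.0)),
      LabeledPoint(0.0, Vectors.dense(2.0)))

    try {
      // The new check runs while aggregating training data, so the job
      // aborts and surfaces as a SparkException at the driver.
      NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli")
    } catch {
      case e: SparkException =>
        println(s"rejected non-binary training data: ${e.getMessage}")
    }
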

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/61e05fc5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/61e05fc5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/61e05fc5

Branch: refs/heads/master
Commit: 61e05fc58e1245de871c409b60951745b5db3420
Parents: 5db18ba
Author: leahmcguire <lm...@salesforce.com>
Authored: Wed May 13 14:13:19 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed May 13 14:13:19 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/classification/NaiveBayes.scala | 28 +++++++++++++++--
 .../mllib/classification/NaiveBayesSuite.scala  | 33 ++++++++++++++++++++
 2 files changed, 58 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/61e05fc5/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index c9b3ff0..b381dc2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -87,12 +87,17 @@ class NaiveBayesModel private[mllib] (
   }
 
   override def predict(testData: Vector): Double = {
+    val brzData = testData.toBreeze
     modelType match {
       case "Multinomial" =>
-        labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
+        labels (brzArgmax (brzPi + brzTheta * brzData) )
       case "Bernoulli" =>
+        if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
+          throw new SparkException(
+            s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+        }
         labels (brzArgmax (brzPi +
-          (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
+          (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
       case _ =>
         // This should never happen.
         throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
@@ -293,12 +298,29 @@ class NaiveBayes private (
       }
     }
 
+    val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
+      val values = v match {
+        case SparseVector(size, indices, values) =>
+          values
+        case DenseVector(values) =>
+          values
+      }
+      if (!values.forall(v => v == 0.0 || v == 1.0)) {
+        throw new SparkException(
+          s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
+      }
+    }
+
     // Aggregates term frequencies per label.
     // TODO: Calling combineByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
     val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
       createCombiner = (v: Vector) => {
-        requireNonnegativeValues(v)
+        if (modelType == "Bernoulli") {
+          requireZeroOneBernoulliValues(v)
+        } else {
+          requireNonnegativeValues(v)
+        }
         (1L, v.toBreeze.toDenseVector)
       },
       mergeValue = (c: (Long, BDV[Double]), v: Vector) => {

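The predict-time check above matters because the Bernoulli scoring expression, brzPi + (brzTheta - brzNegTheta) * x + brzNegThetaSum, only computes the correct class scores when every feature value is 0 or 1. A minimal prediction-time sketch (again not part of the commit; it assumes a SparkContext `sc` and hypothetical data) showing the single-vector predict rejecting non-binary input:

    import org.apache.spark.SparkException
    import org.apache.spark.mllib.classification.NaiveBayes
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Hypothetical 0/1 training data, so training passes the new check.
    val okTrain = Seq(
      LabeledPoint(1.0, Vectors.dense(1.0)),
      LabeledPoint(0.0, Vectors.dense(0.0)),
      LabeledPoint(1.0, Vectors.dense(1.0)))
    val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli")

    try {
      model.predict(Vectors.dense(2.0)) // fails the v == 0.0 || v == 1.0 check
    } catch {
      case e: SparkException =>
        println(s"rejected non-binary features: ${e.getMessage}")
    }
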
http://git-wip-us.apache.org/repos/asf/spark/blob/61e05fc5/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index ea89b17..40a79a1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -208,6 +208,39 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("detect non zero or one values in Bernoulli") {
+    val badTrain = Seq(
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0)))
+
+    intercept[SparkException] {
+      NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli")
+    }
+
+    val okTrain = Seq(
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0))
+    )
+
+    val badPredict = Seq(
+      Vectors.dense(1.0),
+      Vectors.dense(2.0),
+      Vectors.dense(1.0),
+      Vectors.dense(0.0))
+
+    val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli")
+    intercept[SparkException] {
+      model.predict(sc.makeRDD(badPredict, 2)).collect()
+    }
+  }
+
   test("model save/load: 2.0 to 2.0") {
     val tempDir = Utils.createTempDir()
     val path = tempDir.toURI.toString


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org