Posted to commits@spark.apache.org by yl...@apache.org on 2016/11/12 14:14:15 UTC

spark git commit: [SPARK-14077][ML][FOLLOW-UP] Minor refactor and cleanup for NaiveBayes

Repository: spark
Updated Branches:
  refs/heads/master bc41d997e -> 22cb3a060


[SPARK-14077][ML][FOLLOW-UP] Minor refactor and cleanup for NaiveBayes

## What changes were proposed in this pull request?
* Refactor out `trainWithLabelCheck` and make `mllib.NaiveBayes` call into it.
* Avoid capturing the outer object for `modelType`.
* Move `requireNonnegativeValues` and `requireZeroOneBernoulliValues` to the companion object.
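
For context on the second bullet: referencing `$(modelType)` inside a closure that runs on executors pulls the whole estimator into the closure, because the param getter goes through `this`. Copying the param value into a local `val` first means only that value is serialized. A minimal sketch of the pattern, with an illustrative class that is not the actual NaiveBayes code:

```scala
import org.apache.spark.rdd.RDD

// Hypothetical estimator-like class, only to illustrate the capture issue.
class CountingEstimator(val modelType: String) {

  // Referencing the field directly inside the closure captures `this`, so the
  // whole CountingEstimator would be shipped with the task (and here, since the
  // class is not Serializable, Spark fails with "Task not serializable").
  def countMatchesCapturingThis(data: RDD[String]): Long =
    data.filter(v => v == modelType).count()

  // The pattern applied in this patch: copy the value into a local val first,
  // so the closure only captures a String.
  def countMatches(data: RDD[String]): Long = {
    val localModelType = modelType
    data.filter(v => v == localModelType).count()
  }
}
```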

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <yb...@gmail.com>

Closes #15826 from yanboliang/spark-14077-2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/22cb3a06
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/22cb3a06
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/22cb3a06

Branch: refs/heads/master
Commit: 22cb3a060a440205281b71686637679645454ca6
Parents: bc41d99
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sat Nov 12 06:13:22 2016 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Sat Nov 12 06:13:22 2016 -0800

----------------------------------------------------------------------
 .../spark/ml/classification/NaiveBayes.scala    | 72 ++++++++++----------
 .../spark/mllib/classification/NaiveBayes.scala |  6 +-
 2 files changed, 39 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/22cb3a06/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index b03a07a..f1a7676 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -76,7 +76,7 @@ class NaiveBayes @Since("1.5.0") (
   extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
   with NaiveBayesParams with DefaultParamsWritable {
 
-  import NaiveBayes.{Bernoulli, Multinomial}
+  import NaiveBayes._
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("nb"))
@@ -110,21 +110,20 @@ class NaiveBayes @Since("1.5.0") (
   @Since("2.1.0")
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
+  override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
+    trainWithLabelCheck(dataset, positiveLabel = true)
+  }
+
   /**
    * ml assumes input labels in range [0, numClasses). But this implementation
    * is also called by mllib NaiveBayes which allows other kinds of input labels
-   * such as {-1, +1}. Here we use this parameter to switch between different processing logic.
-   * It should be removed when we remove mllib NaiveBayes.
+   * such as {-1, +1}. `positiveLabel` is used to determine whether the label
+   * should be checked and it should be removed when we remove mllib NaiveBayes.
    */
-  private[spark] var isML: Boolean = true
-
-  private[spark] def setIsML(isML: Boolean): this.type = {
-    this.isML = isML
-    this
-  }
-
-  override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
-    if (isML) {
+  private[spark] def trainWithLabelCheck(
+      dataset: Dataset[_],
+      positiveLabel: Boolean): NaiveBayesModel = {
+    if (positiveLabel) {
       val numClasses = getNumClasses(dataset)
       if (isDefined(thresholds)) {
         require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -133,28 +132,9 @@ class NaiveBayes @Since("1.5.0") (
       }
     }
 
-    val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
-      val values = v match {
-        case sv: SparseVector => sv.values
-        case dv: DenseVector => dv.values
-      }
-
-      require(values.forall(_ >= 0.0),
-        s"Naive Bayes requires nonnegative feature values but found $v.")
-    }
-
-    val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
-      val values = v match {
-        case sv: SparseVector => sv.values
-        case dv: DenseVector => dv.values
-      }
-
-      require(values.forall(v => v == 0.0 || v == 1.0),
-        s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
-    }
-
+    val modelTypeValue = $(modelType)
     val requireValues: Vector => Unit = {
-      $(modelType) match {
+      modelTypeValue match {
         case Multinomial =>
           requireNonnegativeValues
         case Bernoulli =>
@@ -226,13 +206,33 @@ class NaiveBayes @Since("1.5.0") (
 @Since("1.6.0")
 object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
   /** String name for multinomial model type. */
-  private[spark] val Multinomial: String = "multinomial"
+  private[classification] val Multinomial: String = "multinomial"
 
   /** String name for Bernoulli model type. */
-  private[spark] val Bernoulli: String = "bernoulli"
+  private[classification] val Bernoulli: String = "bernoulli"
 
   /* Set of modelTypes that NaiveBayes supports */
-  private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
+  private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
+
+  private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = {
+    val values = v match {
+      case sv: SparseVector => sv.values
+      case dv: DenseVector => dv.values
+    }
+
+    require(values.forall(_ >= 0.0),
+      s"Naive Bayes requires nonnegative feature values but found $v.")
+  }
+
+  private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = {
+    val values = v match {
+      case sv: SparseVector => sv.values
+      case dv: DenseVector => dv.values
+    }
+
+    require(values.forall(v => v == 0.0 || v == 1.0),
+      s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
+  }
 
   @Since("1.6.0")
   override def load(path: String): NaiveBayes = super.load(path)

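The hunk above also covers the third bullet: the two feature-value checks move from local closures created on every `train` call to methods on the `NaiveBayes` companion object, so they are defined once and, when used as a `Vector => Unit` function value, capture no outer instance. A rough sketch of the general pattern, with illustrative names rather than the actual helpers:

```scala
object Validators {
  // Defined on an object: using this as a function value
  // captures no enclosing instance.
  def requirePositive(x: Double): Unit =
    require(x > 0.0, s"expected a positive value but found $x")
}

class Trainer {
  def train(values: Seq[Double]): Unit = {
    // Eta-expansion produces a plain Function1 with no reference to `this`.
    val check: Double => Unit = Validators.requirePositive
    values.foreach(check)
  }
}
```
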
http://git-wip-us.apache.org/repos/asf/spark/blob/22cb3a06/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 33561be..767d056 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -364,12 +364,12 @@ class NaiveBayes private (
     val nb = new NewNaiveBayes()
       .setModelType(modelType)
       .setSmoothing(lambda)
-      .setIsML(false)
 
     val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) }
       .toDF("label", "features")
 
-    val newModel = nb.fit(dataset)
+    // mllib NaiveBayes allows input labels like {-1, +1}, so set `positiveLabel` as false.
+    val newModel = nb.trainWithLabelCheck(dataset, positiveLabel = false)
 
     val pi = newModel.pi.toArray
     val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0)
@@ -378,7 +378,7 @@ class NaiveBayes private (
         theta(i)(j) = v
     }
 
-    require(newModel.oldLabels != null,
+    assert(newModel.oldLabels != null,
       "The underlying ML NaiveBayes training does not produce labels.")
     new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType)
   }


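For reference, the public `ml` API is untouched by this refactor: a caller still goes through `fit`, whose `train` now delegates to `trainWithLabelCheck(dataset, positiveLabel = true)`. A minimal, hypothetical usage sketch (the tiny DataFrame below is made up for illustration):

```scala
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object NaiveBayesExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("nb-example")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._

    // Tiny made-up training set; ml expects labels in [0, numClasses).
    val train = Seq(
      (0.0, Vectors.dense(1.0, 0.0, 0.0)),
      (1.0, Vectors.dense(0.0, 1.0, 2.0)),
      (1.0, Vectors.dense(0.0, 2.0, 1.0))
    ).toDF("label", "features")

    val nb = new NaiveBayes()
      .setModelType("multinomial")
      .setSmoothing(1.0)

    // fit() goes through train(), which delegates to
    // trainWithLabelCheck(dataset, positiveLabel = true).
    val model = nb.fit(train)
    model.transform(train).show()

    spark.stop()
  }
}
```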