You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/08/04 20:44:59 UTC

spark git commit: [SPARK-16863][ML] ProbabilisticClassifier.fit should check the length of thresholds

Repository: spark
Updated Branches:
  refs/heads/master 1d781572e -> 0e2e5d7d0


[SPARK-16863][ML] ProbabilisticClassifier.fit should check the length of thresholds

## What changes were proposed in this pull request?

Add a check that the length of `thresholds` equals the number of classes, for classifiers that extend ProbabilisticClassifier.

## How was this patch tested?

unit tests and manual tests

Author: Zheng RuiFeng <ru...@foxmail.com>

Closes #14470 from zhengruifeng/classifier_check_setThreshoulds_length.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0e2e5d7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0e2e5d7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0e2e5d7d

Branch: refs/heads/master
Commit: 0e2e5d7d0b42226c61c3200fd63d2831c558519d
Parents: 1d78157
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Thu Aug 4 21:44:54 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Thu Aug 4 21:44:54 2016 +0100

----------------------------------------------------------------------
 .../spark/ml/classification/DecisionTreeClassifier.scala     | 7 +++++++
 .../apache/spark/ml/classification/LogisticRegression.scala  | 6 ++++++
 .../org/apache/spark/ml/classification/NaiveBayes.scala      | 8 ++++++++
 .../spark/ml/classification/RandomForestClassifier.scala     | 7 +++++++
 4 files changed, 28 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0e2e5d7d/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 7129301..bb192ab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -84,6 +84,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
+
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".train() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }
+
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0e2e5d7d/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 7694773..90baa41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -292,6 +292,12 @@ class LogisticRegression @Since("1.2.0") (
     val numClasses = histogram.length
     val numFeatures = summarizer.mean.size
 
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".train() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }
+
     instr.logNumClasses(numClasses)
     instr.logNumFeatures(numFeatures)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0e2e5d7d/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index ab977c8..f939a1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -101,6 +101,14 @@ class NaiveBayes @Since("1.5.0") (
   setDefault(modelType -> OldNaiveBayes.Multinomial)
 
   override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
+    val numClasses = getNumClasses(dataset)
+
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".train() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }
+
     val oldDataset: RDD[OldLabeledPoint] =
       extractLabeledPoints(dataset).map(OldLabeledPoint.fromML)
     val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))

http://git-wip-us.apache.org/repos/asf/spark/blob/0e2e5d7d/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 4ab132e..52345b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -100,6 +100,13 @@ class RandomForestClassifier @Since("1.4.0") (
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
+
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".train() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }
+
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org