You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/08/04 20:44:59 UTC
spark git commit: [SPARK-16863][ML] ProbabilisticClassifier.fit check
thresholds' length
Repository: spark
Updated Branches:
refs/heads/master 1d781572e -> 0e2e5d7d0
[SPARK-16863][ML] ProbabilisticClassifier.fit check thresholds' length
## What changes were proposed in this pull request?
Add a check of the thresholds' length for classifiers that extend ProbabilisticClassifier
## How was this patch tested?
unit tests and manual tests
Author: Zheng RuiFeng <ru...@foxmail.com>
Closes #14470 from zhengruifeng/classifier_check_setThreshoulds_length.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0e2e5d7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0e2e5d7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0e2e5d7d
Branch: refs/heads/master
Commit: 0e2e5d7d0b42226c61c3200fd63d2831c558519d
Parents: 1d78157
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Thu Aug 4 21:44:54 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Thu Aug 4 21:44:54 2016 +0100
----------------------------------------------------------------------
.../spark/ml/classification/DecisionTreeClassifier.scala | 7 +++++++
.../apache/spark/ml/classification/LogisticRegression.scala | 6 ++++++
.../org/apache/spark/ml/classification/NaiveBayes.scala | 8 ++++++++
.../spark/ml/classification/RandomForestClassifier.scala | 7 +++++++
4 files changed, 28 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/0e2e5d7d/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 7129301..bb192ab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -84,6 +84,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
+
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
http://git-wip-us.apache.org/repos/asf/spark/blob/0e2e5d7d/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 7694773..90baa41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -292,6 +292,12 @@ class LogisticRegression @Since("1.2.0") (
val numClasses = histogram.length
val numFeatures = summarizer.mean.size
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)
http://git-wip-us.apache.org/repos/asf/spark/blob/0e2e5d7d/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index ab977c8..f939a1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -101,6 +101,14 @@ class NaiveBayes @Since("1.5.0") (
setDefault(modelType -> OldNaiveBayes.Multinomial)
override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
+ val numClasses = getNumClasses(dataset)
+
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
val oldDataset: RDD[OldLabeledPoint] =
extractLabeledPoints(dataset).map(OldLabeledPoint.fromML)
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
http://git-wip-us.apache.org/repos/asf/spark/blob/0e2e5d7d/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 4ab132e..52345b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -100,6 +100,13 @@ class RandomForestClassifier @Since("1.4.0") (
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
+
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org