You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/08/02 14:22:54 UTC

spark git commit: [SPARK-16851][ML] Incorrect thresholds length in 'setThresholds()' evokes an exception

Repository: spark
Updated Branches:
  refs/heads/master a1ff72e1c -> d9e0919d3


[SPARK-16851][ML] Incorrect thresholds length in 'setThresholds()' evokes an exception

## What changes were proposed in this pull request?
Add a check that the length of `thresholds` matches `numClasses` in the `setThresholds()` method of probabilistic classification models.

## How was this patch tested?
unit tests

Author: Zheng RuiFeng <ru...@foxmail.com>

Closes #14457 from zhengruifeng/check_setThresholds.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d9e0919d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d9e0919d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d9e0919d

Branch: refs/heads/master
Commit: d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e
Parents: a1ff72e
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Tue Aug 2 07:22:41 2016 -0700
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Tue Aug 2 07:22:41 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/ProbabilisticClassifier.scala     | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d9e0919d/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 88642ab..19df8f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -83,7 +83,12 @@ abstract class ProbabilisticClassificationModel[
   def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
 
   /** @group setParam */
-  def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
+  def setThresholds(value: Array[Double]): M = {
+    require(value.length == numClasses, this.getClass.getSimpleName +
+      ".setThresholds() called with non-matching numClasses and thresholds.length." +
+      s" numClasses=$numClasses, but thresholds has length ${value.length}")
+    set(thresholds, value).asInstanceOf[M]
+  }
 
   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org