You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/08/02 14:22:54 UTC
spark git commit: [SPARK-16851][ML] Incorrect thresholds length in
'setThresholds()' should throw an exception
Repository: spark
Updated Branches:
refs/heads/master a1ff72e1c -> d9e0919d3
[SPARK-16851][ML] Incorrect thresholds length in 'setThresholds()' should throw an exception
## What changes were proposed in this pull request?
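Add a check to the `setThresholds()` method of probabilistic classification models that the length of `thresholds` equals `numClasses`, so that an incorrect length fails immediately instead of causing an exception later.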
## How was this patch tested?
unit tests
Author: Zheng RuiFeng <ru...@foxmail.com>
Closes #14457 from zhengruifeng/check_setThresholds.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d9e0919d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d9e0919d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d9e0919d
Branch: refs/heads/master
Commit: d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e
Parents: a1ff72e
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Tue Aug 2 07:22:41 2016 -0700
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Tue Aug 2 07:22:41 2016 -0700
----------------------------------------------------------------------
.../spark/ml/classification/ProbabilisticClassifier.scala | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/d9e0919d/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 88642ab..19df8f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -83,7 +83,12 @@ abstract class ProbabilisticClassificationModel[
def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
/** @group setParam */
- def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
+ def setThresholds(value: Array[Double]): M = {
+ require(value.length == numClasses, this.getClass.getSimpleName +
+ ".setThresholds() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${value.length}")
+ set(thresholds, value).asInstanceOf[M]
+ }
/**
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org