Repository: spark
Updated Branches:
  refs/heads/master a1ff72e1c -> d9e0919d3
[SPARK-16851][ML] Incorrect thresholds length in setThresholds() throws an Exception

## What changes were proposed in this pull request?
Add a length check for the thresholds array in the `setThresholds()` method of classification models.

## How was this patch tested?
Unit tests.

Author: Zheng RuiFeng <ruife...@foxmail.com>

Closes #14457 from zhengruifeng/check_setThresholds.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d9e0919d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d9e0919d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d9e0919d

Branch: refs/heads/master
Commit: d9e0919d30e9f79a0eb1ceb8d1b5e9fc58cf085e
Parents: a1ff72e
Author: Zheng RuiFeng <ruife...@foxmail.com>
Authored: Tue Aug 2 07:22:41 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Tue Aug 2 07:22:41 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/ProbabilisticClassifier.scala | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/d9e0919d/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 88642ab..19df8f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -83,7 +83,12 @@ abstract class ProbabilisticClassificationModel[
   def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]

   /** @group setParam */
-  def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
+  def setThresholds(value: Array[Double]): M = {
+    require(value.length == numClasses, this.getClass.getSimpleName +
+      ".setThresholds() called with non-matching numClasses and thresholds.length." +
+      s" numClasses=$numClasses, but thresholds has length ${value.length}")
+    set(thresholds, value).asInstanceOf[M]
+  }

   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
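For illustration, the validation pattern this commit introduces can be sketched outside of Spark with a minimal standalone class. `DemoModel` below is hypothetical (it is not the actual `ProbabilisticClassificationModel`, which also interacts with Spark's `Param` machinery); it only demonstrates how `require` makes a mismatched thresholds length fail fast with an `IllegalArgumentException` instead of surfacing later during prediction:

```scala
// Hypothetical stand-in for a classification model; numClasses would normally
// come from training metadata rather than the constructor.
class DemoModel(val numClasses: Int) {
  private var thresholds: Option[Array[Double]] = None

  def setThresholds(value: Array[Double]): this.type = {
    // Same check the commit adds: reject arrays whose length != numClasses.
    require(value.length == numClasses,
      s"setThresholds() called with non-matching numClasses and thresholds.length." +
        s" numClasses=$numClasses, but thresholds has length ${value.length}")
    thresholds = Some(value)
    this
  }

  def getThresholds: Option[Array[Double]] = thresholds
}
```

A call such as `new DemoModel(3).setThresholds(Array(0.5, 0.5))` now throws `IllegalArgumentException` immediately, while a correctly sized array is accepted.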