zhengruifeng created SPARK-16863:
------------------------------------
Summary: ProbabilisticClassifier.fit should check the length of thresholds
Key: SPARK-16863
URL: https://issues.apache.org/jira/browse/SPARK-16863
Project: Spark
Issue Type: Improvement
Components: ML
Reporter: zhengruifeng
Priority: Minor
{code}
import org.apache.spark.ml.classification.RandomForestClassifier
val path = "./spark-2.0.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt"
val data = spark.read.format("libsvm").load(path)
val rf = new RandomForestClassifier
rf.setThresholds(Array(0.1, 0.2, 0.3, 0.4, 0.5))
val rfm = rf.fit(data)
rfm: org.apache.spark.ml.classification.RandomForestClassificationModel = RandomForestClassificationModel (uid=rfc_fec31a5b954d) with 20 trees
rfm.numClasses
res2: Int = 3
rfm.getThresholds
res3: Array[Double] = Array(0.1, 0.2, 0.3, 0.4, 0.5)
rfm.transform(data)
java.lang.IllegalArgumentException: requirement failed: RandomForestClassificationModel.transform() called with non-matching numClasses and thresholds.length. numClasses=3, but thresholds has length 5
  at scala.Predef$.require(Predef.scala:224)
  at org.apache.spark.ml.classification.ProbabilisticClassificationModel.transform(ProbabilisticClassifier.scala:101)
  ... 72 elided
{code}
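{{ProbabilisticClassifier.fit()}} succeeds even though the 5 thresholds do not match the 3 classes; the mismatch is only detected later, in {{transform()}}. {{fit()}} should instead throw an exception when its thresholds are set incorrectly, i.e. when {{thresholds.length}} does not equal {{numClasses}}.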
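A minimal sketch of the proposed fit-time check, assuming {{numClasses}} is already known when {{fit()}} runs (e.g. inferred from the label column) and that the thresholds are passed in as an {{Option}}. The object {{ThresholdCheck}} and the method {{validateThresholds}} are hypothetical names for illustration, not part of Spark's API; the error message mirrors the one {{transform()}} currently raises.

{code}
// Hypothetical helper, not part of Spark: illustrates failing fast in fit()
// instead of waiting for transform() to detect the mismatch.
object ThresholdCheck {
  /** Throws IllegalArgumentException if thresholds are set and their length
    * differs from numClasses; does nothing if thresholds are unset. */
  def validateThresholds(
      thresholds: Option[Array[Double]],
      numClasses: Int,
      modelName: String): Unit = {
    thresholds.foreach { t =>
      // Predef.require throws IllegalArgumentException("requirement failed: ...")
      require(t.length == numClasses,
        s"$modelName.fit() called with non-matching numClasses and thresholds.length. " +
          s"numClasses=$numClasses, but thresholds has length ${t.length}")
    }
  }
}
{code}

With the thresholds from the example above ({{Some(Array(0.1, 0.2, 0.3, 0.4, 0.5))}} and {{numClasses = 3}}), this call would fail at training time rather than at {{transform()}}.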