Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/12663#discussion_r60967938
--- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala ---
@@ -62,6 +65,76 @@ abstract class Classifier[
   def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]

   // TODO: defaultEvaluator  (follow-up PR)
+
+  /**
+   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
+   * and put it in an RDD with strong types.
+   *
+   * @throws SparkException  if any label is not an integer >= 0
+   */
+  override protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
+    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+      case Row(label: Double, features: Vector) =>
+        require(label % 1 == 0 && label >= 0, s"Classifier was given dataset with invalid label" +
+          s" $label. Labels must be integers in range [0, 1, ..., numClasses-1]")
+        LabeledPoint(label, features)
+    }
+  }
+
+  /**
+   * Get the number of classes. This looks in column metadata first, and if that is missing,
+   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+   * by finding the maximum label value.
+   *
+   * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+   * such as in [[extractLabeledPoints()]].
+   *
+   * @param dataset  Dataset which contains a column [[labelCol]]
+   * @param maxNumClasses  Maximum number of classes allowed when inferred from data. If numClasses
+   *                       is specified in the metadata, then maxNumClasses is ignored.
+   * @return  number of classes
+   * @throws IllegalArgumentException  if metadata does not specify numClasses, and the
+   *                                   actual numClasses exceeds maxNumClasses
+   */
+  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 1000): Int = {
+    MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+      case Some(n: Int) => n
+      case None =>
+        // Get number of classes from dataset itself.
+        val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1)
+        if (maxLabelRow.isEmpty) {
+          throw new SparkException("ML algorithm was given empty dataset.")
+        }
+        val maxLabel: Int = maxLabelRow.head.getDouble(0).toInt
+        val numClasses = maxLabel + 1
+        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
--- End diff --
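For context, here is a minimal sketch of what the fallback branch above computes when the label column carries no metadata (public DataFrame APIs only; the SparkSession setup and sample values are made up for illustration):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.max

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Labels 0.0..2.0 with no numClasses metadata on the column:
    // getNumClasses would infer numClasses = max(label) + 1 = 3,
    // subject to the maxNumClasses = 1000 cap being discussed here.
    val df = Seq(0.0, 1.0, 2.0, 2.0).toDF("label")
    val inferred = df.select(max($"label")).head.getDouble(0).toInt + 1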
I agree we should set a limit here. It might not be clear to someone who
receives this error that they _can_ have more than 1000 classes if they set
the metadata themselves. Maybe the last sentence could say "For label columns
with more than $maxNumClasses classes, specify numClasses explicitly in the
metadata, such as ..."
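For concreteness, one way a user could declare numClasses explicitly via column metadata (a sketch using the public NominalAttribute API; the 5000-class count is hypothetical, and df is any DataFrame with a Double "label" column):

    import org.apache.spark.ml.attribute.NominalAttribute

    // Declare 5000 classes up front so the metadata branch of getNumClasses
    // is taken and the maxNumClasses cap never applies.
    val labelMeta = NominalAttribute.defaultAttr.withNumValues(5000).toMetadata()
    val withMeta = df.withColumn("label", df("label").as("label", labelMeta))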