zhengruifeng commented on code in PR #36049:
URL: https://github.com/apache/spark/pull/36049#discussion_r841303261
##########
mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala:
##########
@@ -138,4 +140,61 @@ private[spark] object DatasetUtils {
case Row(point: Vector) => OldVectors.fromML(point)
}
}
+
+ /**
+ * Get the number of classes. This looks in column metadata first, and if
that is missing,
+ * then this assumes classes are indexed 0,1,...,numClasses-1 and computes
numClasses
+ * by finding the maximum label value.
+ *
+ * Label validation (ensuring all labels are integers >= 0) needs to be
handled elsewhere,
+ * such as in `extractLabeledPoints()`.
+ *
+ * @param dataset Dataset which contains a column [[labelCol]]
+ * @param maxNumClasses Maximum number of classes allowed when inferred
from data. If numClasses
+ * is specified in the metadata, then maxNumClasses is
ignored.
+ * @return number of classes
+ * @throws IllegalArgumentException if metadata does not specify
numClasses, and the
+ * actual numClasses exceeds maxNumClasses
+ */
+ private[ml] def getNumClasses(
+ dataset: Dataset[_],
+ labelCol: String,
+ maxNumClasses: Int = 100): Int = {
+ MetadataUtils.getNumClasses(dataset.schema(labelCol)) match {
+ case Some(n: Int) => n
+ case None =>
+ // Get number of classes from dataset itself.
+ val maxLabelRow: Array[Row] = dataset
+ .select(max(checkClassificationLabels(labelCol,
Some(maxNumClasses))))
+ .take(1)
+ if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
+ throw new SparkException("ML algorithm was given empty dataset.")
+ }
+ val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+ require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label
value =" +
+ s" $maxDoubleLabel but requires integers in range [0, ...
${Int.MaxValue})")
+ val numClasses = maxDoubleLabel.toInt + 1
+ require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses
from label values" +
+ s" in column $labelCol, but this exceeded the max numClasses
($maxNumClasses) allowed" +
+ s" to be inferred from values. To avoid this error for labels with
> $maxNumClasses" +
+ s" classes, specify numClasses explicitly in the metadata; this can
be done by applying" +
+ s" StringIndexer to the label column.")
+ logInfo(this.getClass.getCanonicalName + s" inferred $numClasses
classes for" +
+ s" labelCol=$labelCol since numClasses was not specified in the
column metadata.")
+ numClasses
+ }
+ }
+
+ /**
+ * Obtain the number of features in a vector column.
+ * If no metadata is available, extract it from the dataset.
+ */
+ private[ml] def getNumFeatures(dataset: Dataset[_], vectorCol: String): Int
= {
+ MetadataUtils.getNumFeatures(dataset.schema(vectorCol)) match {
Review Comment:
OK, will switch back to getOrElse
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe e-mail: [email protected]