GitHub user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20829#discussion_r175906193
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
@@ -136,34 +172,88 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.6.0")
object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
+ private[feature] val SKIP_INVALID: String = "skip"
+ private[feature] val ERROR_INVALID: String = "error"
+ private[feature] val KEEP_INVALID: String = "keep"
+ private[feature] val supportedHandleInvalids: Array[String] =
+ Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
+
+ private[feature] def getLengthsFromFirst(dataset: Dataset[_],
+ columns: Seq[String]): Map[String, Int] = {
+ try {
+ val first_row = dataset.toDF.select(columns.map(col): _*).first
+ columns.zip(first_row.toSeq).map {
+ case (c, x) => c -> x.asInstanceOf[Vector].size
+ }.toMap
+ } catch {
+ case e: NullPointerException => throw new NullPointerException(
+ "Saw null value on the first row: " + e.toString)
+ case e: NoSuchElementException => throw new NoSuchElementException(
+ "Cannot infer vector size from all empty DataFrame" + e.toString)
+ }
+ }
+
+ private[feature] def getLengths(dataset: Dataset[_], columns: Seq[String],
+ handleInvalid: String) = {
+ val group_sizes = columns.map { c =>
+ c -> AttributeGroup.fromStructField(dataset.schema(c)).size
+ }.toMap
+ val missing_columns: Seq[String] = group_sizes.filter(_._2 == -1).keys.toSeq
+ val first_sizes: Map[String, Int] = (missing_columns.nonEmpty, handleInvalid) match {
+ case (true, VectorAssembler.ERROR_INVALID) =>
+ getLengthsFromFirst(dataset, missing_columns)
+ case (true, VectorAssembler.SKIP_INVALID) =>
+ getLengthsFromFirst(dataset.na.drop, missing_columns)
+ case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
+ "Consider using VectorSizeHint for columns: " + missing_columns.mkString("[", ",", "]"))
+ case (_, _) => Map.empty
+ }
+ group_sizes ++ first_sizes
+ }
+
+
@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)
- private[feature] def assemble(vv: Any*): Vector = {
+ private[feature] def assemble(lengths: Seq[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
--- End diff --
nit: Use Array[Int] instead of Seq[Int] for `lengths`, for faster (constant-time) indexed access.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]