Github user jkbradley commented on a diff in the pull request: https://github.com/apache/spark/pull/20829#discussion_r175909539 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala --- @@ -136,34 +172,88 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.6.0") object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + + + private[feature] def getLengthsFromFirst(dataset: Dataset[_], + columns: Seq[String]): Map[String, Int] = { + try { + val first_row = dataset.toDF.select(columns.map(col): _*).first + columns.zip(first_row.toSeq).map { + case (c, x) => c -> x.asInstanceOf[Vector].size + }.toMap + } catch { + case e: NullPointerException => throw new NullPointerException( + "Saw null value on the first row: " + e.toString) + case e: NoSuchElementException => throw new NoSuchElementException( + "Cannot infer vector size from all empty DataFrame" + e.toString) + } + } + + private[feature] def getLengths(dataset: Dataset[_], columns: Seq[String], + handleInvalid: String) = { + val group_sizes = columns.map { c => + c -> AttributeGroup.fromStructField(dataset.schema(c)).size + }.toMap + val missing_columns: Seq[String] = group_sizes.filter(_._2 == -1).keys.toSeq + val first_sizes: Map[String, Int] = (missing_columns.nonEmpty, handleInvalid) match { + case (true, VectorAssembler.ERROR_INVALID) => + getLengthsFromFirst(dataset, missing_columns) + case (true, VectorAssembler.SKIP_INVALID) => + getLengthsFromFirst(dataset.na.drop, missing_columns) + case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException( + "Consider using VectorSizeHint for columns: " + missing_columns.mkString("[", ",", "]")) + case 
(_, _) => Map.empty + } + group_sizes ++ first_sizes + } + + @Since("1.6.0") override def load(path: String): VectorAssembler = super.load(path) - private[feature] def assemble(vv: Any*): Vector = { + private[feature] def assemble(lengths: Seq[Int], keepInvalid: Boolean)(vv: Any*): Vector = { --- End diff -- Also, I'd add documentation explaining the requirements — especially that this method assumes `lengths` and `vv` have the same length.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org