Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20829#discussion_r177547735

    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
    @@ -136,34 +181,106 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     
     @Since("1.6.0")
     object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
    +  private[feature] val SKIP_INVALID: String = "skip"
    +  private[feature] val ERROR_INVALID: String = "error"
    +  private[feature] val KEEP_INVALID: String = "keep"
    +  private[feature] val supportedHandleInvalids: Array[String] =
    +    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
    +
    +  /**
    +   * Infers lengths of vector columns from the first row of the dataset
    +   * @param dataset the dataset
    +   * @param columns name of vector columns whose lengths need to be inferred
    +   * @return map of column names to lengths
    +   */
    +  private[feature] def getVectorLengthsFromFirstRow(
    +      dataset: Dataset[_], columns: Seq[String]): Map[String, Int] = {
    +    try {
    +      val first_row = dataset.toDF().select(columns.map(col): _*).first()
    +      columns.zip(first_row.toSeq).map {
    +        case (c, x) => c -> x.asInstanceOf[Vector].size
    +      }.toMap
    +    } catch {
    +      case e: NullPointerException => throw new NullPointerException(
    +        s"""Encountered null value while inferring lengths from the first row. Consider using
    +           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}.
    +           |""".stripMargin.replaceAll("\n", " ") + e.toString)
    +      case e: NoSuchElementException => throw new NoSuchElementException(
    +        s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
    +           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}.
    +           |""".stripMargin.replaceAll("\n", " ") + e.toString)
    +    }
    +  }
    +
    +  private[feature] def getLengths(dataset: Dataset[_], columns: Seq[String],
    +      handleInvalid: String) = {
    +    val group_sizes = columns.map { c =>
    +      c -> AttributeGroup.fromStructField(dataset.schema(c)).size
    +    }.toMap
    +    val missing_columns: Seq[String] = group_sizes.filter(_._2 == -1).keys.toSeq
    +    val first_sizes: Map[String, Int] = (missing_columns.nonEmpty, handleInvalid) match {
    +      case (true, VectorAssembler.ERROR_INVALID) =>
    +        getVectorLengthsFromFirstRow(dataset, missing_columns)
    +      case (true, VectorAssembler.SKIP_INVALID) =>
    +        getVectorLengthsFromFirstRow(dataset.na.drop(missing_columns), missing_columns)
    +      case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
    +        s"""Can not infer column lengths for 'keep invalid' mode. Consider using
    +           |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}.
    +           |""".stripMargin.replaceAll("\n", " "))
    +      case (_, _) => Map.empty
    +    }
    +    group_sizes ++ first_sizes
    +  }
    +
       @Since("1.6.0")
       override def load(path: String): VectorAssembler = super.load(path)
     
    -  private[feature] def assemble(vv: Any*): Vector = {
    +  /**
    +   * Returns a UDF that has the required information to assemble each row.
    --- End diff --
    
    nit: When people say "UDF," they generally mean a Spark SQL UDF. This is just a function, not a SQL UDF.
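To make the nit concrete: in Spark, a plain Scala function only becomes a SQL UDF once it is wrapped with org.apache.spark.sql.functions.udf. A minimal sketch of the distinction (plusOne and plusOneUdf are hypothetical names, not part of this PR):

    import org.apache.spark.sql.functions.{col, udf}

    // A plain Scala function: callable from ordinary Scala code, but it
    // cannot appear directly inside a DataFrame expression.
    val plusOne: Int => Int = _ + 1

    // The same function wrapped as a Spark SQL UDF; this one can be used
    // in Column expressions, e.g. df.select(plusOneUdf(col("x"))).
    val plusOneUdf = udf(plusOne)

Per the comment above, the helper in this diff returns the first kind, not the second, so the scaladoc's "UDF" wording is misleading.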
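The error messages in getVectorLengthsFromFirstRow and getLengths all point the user to VectorSizeHint. A minimal usage sketch of that workaround, assuming the handleInvalid param lands as proposed in this PR (the column names userFeatures/clicks and the DataFrame df are made up):

    import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}

    // Attach size metadata to the vector column up front, so the assembler
    // never has to infer its length from the first row of the data.
    val sizeHint = new VectorSizeHint()
      .setInputCol("userFeatures")
      .setSize(3)

    val assembler = new VectorAssembler()
      .setInputCols(Array("userFeatures", "clicks"))
      .setOutputCol("features")
      .setHandleInvalid("keep")  // "keep" relies on the metadata added above

    // val assembled = assembler.transform(sizeHint.transform(df))

With the metadata present, no column reports size -1, so missing_columns is empty and the (true, KEEP_INVALID) branch that throws the RuntimeException is never reached.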