Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20829#discussion_r178620200
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
    @@ -136,34 +181,106 @@ class VectorAssembler @Since("1.4.0") 
(@Since("1.4.0") override val uid: String)
     @Since("1.6.0")
     object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
     
    +  private[feature] val SKIP_INVALID: String = "skip"
    +  private[feature] val ERROR_INVALID: String = "error"
    +  private[feature] val KEEP_INVALID: String = "keep"
    +  private[feature] val supportedHandleInvalids: Array[String] =
    +    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
    +
    +  /**
    +   * Infers the lengths of vector columns from the first row of the dataset.
    +   * @param dataset the dataset
    +   * @param columns names of the vector columns whose lengths need to be 
inferred
    +   * @return map of column names to lengths
    +   */
    +  private[feature] def getVectorLengthsFromFirstRow(
    +      dataset: Dataset[_], columns: Seq[String]): Map[String, Int] = {
    +    try {
    +      val first_row = dataset.toDF().select(columns.map(col): _*).first()
    +      columns.zip(first_row.toSeq).map {
    +        case (c, x) => c -> x.asInstanceOf[Vector].size
    +      }.toMap
    +    } catch {
    +      case e: NullPointerException => throw new NullPointerException(
    +        s"""Encountered null value while inferring lengths from the first 
row. Consider using
    +           |VectorSizeHint to add metadata for columns: 
${columns.mkString("[", ", ", "]")}.
    +           |""".stripMargin.replaceAll("\n", " ") + e.toString)
    +      case e: NoSuchElementException => throw new NoSuchElementException(
    +        s"""Encountered empty dataframe while inferring lengths from the 
first row. Consider using
    +           |VectorSizeHint to add metadata for columns: 
${columns.mkString("[", ", ", "]")}.
    +           |""".stripMargin.replaceAll("\n", " ") + e.toString)
    +    }
    +  }
    +
    +  private[feature] def getLengths(dataset: Dataset[_], columns: 
Seq[String],
    +                                  handleInvalid: String) = {
    +    val group_sizes = columns.map { c =>
    +      c -> AttributeGroup.fromStructField(dataset.schema(c)).size
    +    }.toMap
    +    val missing_columns: Seq[String] = group_sizes.filter(_._2 == 
-1).keys.toSeq
    +    val first_sizes: Map[String, Int] = (missing_columns.nonEmpty, 
handleInvalid) match {
    --- End diff --
    
    ping


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to