Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20829#discussion_r175909539
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
    @@ -136,34 +172,88 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     @Since("1.6.0")
     object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
     
    +  private[feature] val SKIP_INVALID: String = "skip"
    +  private[feature] val ERROR_INVALID: String = "error"
    +  private[feature] val KEEP_INVALID: String = "keep"
    +  private[feature] val supportedHandleInvalids: Array[String] =
    +    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
    +
    +
    +  private[feature] def getLengthsFromFirst(dataset: Dataset[_],
    +                                           columns: Seq[String]): 
Map[String, Int] = {
    +    try {
    +      val first_row = dataset.toDF.select(columns.map(col): _*).first
    +      columns.zip(first_row.toSeq).map {
    +        case (c, x) => c -> x.asInstanceOf[Vector].size
    +      }.toMap
    +    } catch {
    +      case e: NullPointerException => throw new NullPointerException(
    +        "Saw null value on the first row: " + e.toString)
    +      case e: NoSuchElementException => throw new NoSuchElementException(
    +        "Cannot infer vector size from all empty DataFrame" + e.toString)
    +    }
    +  }
    +
    +  private[feature] def getLengths(dataset: Dataset[_], columns: 
Seq[String],
    +                                  handleInvalid: String) = {
    +    val group_sizes = columns.map { c =>
    +      c -> AttributeGroup.fromStructField(dataset.schema(c)).size
    +    }.toMap
    +    val missing_columns: Seq[String] = group_sizes.filter(_._2 == 
-1).keys.toSeq
    +    val first_sizes: Map[String, Int] = (missing_columns.nonEmpty, 
handleInvalid) match {
    +      case (true, VectorAssembler.ERROR_INVALID) =>
    +        getLengthsFromFirst(dataset, missing_columns)
    +      case (true, VectorAssembler.SKIP_INVALID) =>
    +        getLengthsFromFirst(dataset.na.drop, missing_columns)
    +      case (true, VectorAssembler.KEEP_INVALID) => throw new 
RuntimeException(
    +        "Consider using VectorSizeHint for columns: " + 
missing_columns.mkString("[", ",", "]"))
    +      case (_, _) => Map.empty
    +    }
    +    group_sizes ++ first_sizes
    +  }
    +
    +
       @Since("1.6.0")
       override def load(path: String): VectorAssembler = super.load(path)
     
    -  private[feature] def assemble(vv: Any*): Vector = {
    +  private[feature] def assemble(lengths: Seq[Int], keepInvalid: 
Boolean)(vv: Any*): Vector = {
    --- End diff --
    
    Also, I'd add a doc comment explaining the requirements, especially that this assumes `lengths` and `vv` have the same length.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to