Github user yogeshg commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20829#discussion_r178605922
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
    @@ -49,32 +55,64 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
       @Since("1.4.0")
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
    +
    +  /**
    +   * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
    +   * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
    +   * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
    +   * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
    +   * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
    +   * Default: "error"
    +   * @group param
    +   */
    +  @Since("2.4.0")
    +  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
    +    """
    +    | Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
    +    | invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
    +    | output). Column lengths are taken from the size of ML Attribute Group, which can be set using
    +    | `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
    +    | from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
    +    | """.stripMargin.replaceAll("\n", " "),
    +    ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
    +
    +  setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
    +
       @Since("2.0.0")
       override def transform(dataset: Dataset[_]): DataFrame = {
         transformSchema(dataset.schema, logging = true)
         // Schema transformation.
         val schema = dataset.schema
    -    lazy val first = dataset.toDF.first()
    -    val attrs = $(inputCols).flatMap { c =>
    +
    +    val vectorCols = $(inputCols).toSeq.filter { c =>
    +      schema(c).dataType match {
    +        case _: VectorUDT => true
    +        case _ => false
    +      }
    +    }
    +    val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
    +
    +    val featureAttributesMap = $(inputCols).toSeq.map { c =>
    --- End diff --
    
    We need the intermediate `map` result to find out the lengths of the vectors. Unless
    there's a way to do this in a single mapping pass, I think it is better to first call
    a `map` and then a `flatMap`.
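
    For illustration, here is a simplified sketch of the two-pass shape I mean, using
    plain Scala collections and made-up column names rather than the real `transform`
    code, so treat it as a rough outline and not the actual implementation:

    ```scala
    // Sketch only: `vectorColLengths` stands in for the per-column lengths that
    // VectorAssembler.getLengths computes for the vector-typed input columns.
    object MapThenFlattenSketch {
      def main(args: Array[String]): Unit = {
        val inputCols = Seq("hour", "mobile", "userFeatures")   // hypothetical input columns
        val vectorColLengths = Map("userFeatures" -> 3)          // only vector columns have an entry

        // Pass 1 (`map`): one entry per input column, holding that column's output slots.
        val featureAttributesMap: Seq[Seq[String]] = inputCols.map { c =>
          vectorColLengths.get(c) match {
            case Some(n) => (0 until n).map(i => s"${c}_$i")    // a vector column expands to n slots
            case None    => Seq(c)                               // a scalar column contributes one slot
          }
        }

        // The per-column lengths are still available here, which is why the intermediate
        // `map` result is kept instead of going straight to a flatMap.
        val lengths = featureAttributesMap.map(_.length)         // Seq(1, 1, 3)

        // Pass 2 (flatten): the single attribute list for the assembled output column.
        val featureAttributes = featureAttributesMap.flatten

        println(lengths)
        println(featureAttributes)
      }
    }
    ```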


---
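
    On the `handleInvalid` doc quoted above, here is a rough sketch of how I picture a
    user wiring this up once the param is in (hypothetical data, column names, and object
    name): `VectorSizeHint` declares the vector size up front, so that 'keep' knows how
    many NaNs to emit when the vector column is NULL.

    ```scala
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}
    import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors}
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

    object HandleInvalidSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("handleInvalid-sketch").getOrCreate()

        // One clean row and one row with NULLs in both a numeric and a vector column.
        val schema = StructType(Seq(
          StructField("hour", DoubleType, nullable = true),
          StructField("userFeatures", SQLDataTypes.VectorType, nullable = true)))
        val rows = Seq(Row(18.0, Vectors.dense(0.0, 10.0, 0.5)), Row(null, null))
        val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

        // Declare the vector size up front; 'optimistic' lets the NULL vector pass through
        // so the assembler's own handleInvalid policy decides what to do with it.
        val sizeHint = new VectorSizeHint()
          .setInputCol("userFeatures")
          .setSize(3)
          .setHandleInvalid("optimistic")

        val assembler = new VectorAssembler()
          .setInputCols(Array("hour", "userFeatures"))
          .setOutputCol("features")
          .setHandleInvalid("keep")   // 'skip' would drop the second row, 'error' would fail on it

        val pipeline = new Pipeline().setStages(Array(sizeHint, assembler))
        pipeline.fit(df).transform(df).show(truncate = false)

        spark.stop()
      }
    }
    ```

    Note that `VectorSizeHint` has its own `handleInvalid`, which only controls validation
    of the vector column itself; the new param on `VectorAssembler` decides what happens
    to the rows with NULLs.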
