Github user yogeshg commented on a diff in the pull request:
https://github.com/apache/spark/pull/20829#discussion_r178605922
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
@@ -49,32 +55,64 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
+  /** @group setParam */
+  @Since("2.4.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /**
+   * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+   * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+   * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+   * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+   * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.4.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    """
+    | Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+    | invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+    | output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+    | `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+    | from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+    | """.stripMargin.replaceAll("\n", " "),
+    ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+  setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     // Schema transformation.
     val schema = dataset.schema
-    lazy val first = dataset.toDF.first()
-    val attrs = $(inputCols).flatMap { c =>
+
+    val vectorCols = $(inputCols).toSeq.filter { c =>
+      schema(c).dataType match {
+        case _: VectorUDT => true
+        case _ => false
+      }
+    }
+    val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
+
+    val featureAttributesMap = $(inputCols).toSeq.map { c =>
--- End diff --
We need the `map` to find out the lengths of the vector columns. Unless there is a way to do this in a single mapping pass, I think calling a `map` first and then a `flatMap` is the better option.
---
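To make the semantics under discussion concrete, here is a minimal, hypothetical Python sketch of the length-resolution rule described in the scaladoc above. It is NOT the Spark API (Spark's `VectorAssembler.getLengths` operates on a `Dataset` and ML Attribute Groups); the function name, the `group_sizes` dict, and the row representation are illustrative stand-ins.

```python
# Hypothetical model of the length-resolution logic from the diff above.
# get_lengths, group_sizes, and the row format are illustrative, not Spark API.

SKIP, ERROR, KEEP = "skip", "error", "keep"

def get_lengths(rows, vector_cols, group_sizes, handle_invalid):
    """Resolve a length for each vector column.

    group_sizes maps a column name to its known size (e.g. set by
    VectorSizeHint earlier in a pipeline); a missing entry means unknown.
    Unknown sizes may be inferred from the first non-null row, but only
    under 'error' or 'skip', where invalid rows are rejected or dropped.
    Under 'keep' a null must become NaNs of the right length, so an
    unknown length cannot be safely inferred and is an error.
    """
    lengths = {}
    for c in vector_cols:
        if c in group_sizes:
            lengths[c] = group_sizes[c]  # size known up front
        elif handle_invalid in (ERROR, SKIP):
            # Safe to peek at the data: a null here would fail anyway.
            first = next(r[c] for r in rows if r[c] is not None)
            lengths[c] = len(first)
        else:
            raise ValueError(
                f"Cannot infer length of column {c!r} with "
                f"handleInvalid='{KEEP}'; declare it via a size hint.")
    return lengths
```

For example, with rows `[{"v": None}, {"v": [1.0, 2.0]}]`, `'skip'` infers a length of 2 from the first non-null value, while `'keep'` without a declared size raises.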
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]