Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20829#discussion_r177503206
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
@@ -49,32 +55,64 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+ /**
+ * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+ * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+ * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+ * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+ * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+ * Default: "error"
+ * @group param
+ */
+ @Since("2.4.0")
+ override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+ """
+ | Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+ | invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+ | output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+ | `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+ | from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+ | """.stripMargin.replaceAll("\n", " "),
+ ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+ setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Schema transformation.
val schema = dataset.schema
- lazy val first = dataset.toDF.first()
- val attrs = $(inputCols).flatMap { c =>
+
+ val vectorCols = $(inputCols).toSeq.filter { c =>
--- End diff --
nit: Is toSeq extraneous?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]