Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20829#discussion_r174994214
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
@@ -49,32 +51,65 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0")
override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ @Since("1.6.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid,
value)
+
+ /**
+ * Param for how to handle invalid data (NULL values). Options are
'skip' (filter out rows with
+ * invalid data), 'error' (throw an error), or 'keep' (return relevant
number of NaN in the
+ * output).
+ * Default: "error"
+ * @group param
+ */
+ @Since("1.6.0")
+ override val handleInvalid: Param[String] = new Param[String](this,
"handleInvalid",
"How to handle invalid data (NULL values). Options are 'skip' (filter
out rows with " +
"invalid data), 'error' (throw an error), or 'keep' (return relevant
number of NaN " +
"in the output).",
ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+ setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Schema transformation.
val schema = dataset.schema
- lazy val first = dataset.toDF.first()
- val attrs = $(inputCols).flatMap { c =>
+
+ val featureAttributesMap: Seq[Seq[Attribute]] = $(inputCols).toSeq.map
{ c =>
val field = schema(c)
- val index = schema.fieldIndex(c)
field.dataType match {
- case DoubleType =>
- val attr = Attribute.fromStructField(field)
- // If the input column doesn't have ML attribute, assume numeric.
- if (attr == UnresolvedAttribute) {
- Some(NumericAttribute.defaultAttr.withName(c))
- } else {
- Some(attr.withName(c))
- }
- case _: NumericType | BooleanType =>
- // If the input column type is a compatible scalar type, assume
numeric.
- Some(NumericAttribute.defaultAttr.withName(c))
case _: VectorUDT =>
- val group = AttributeGroup.fromStructField(field)
- if (group.attributes.isDefined) {
- // If attributes are defined, copy them with updated names.
- group.attributes.get.zipWithIndex.map { case (attr, i) =>
+ val attributeGroup = AttributeGroup.fromStructField(field)
+ var length = attributeGroup.size
+ val isMissingNumAttrs = -1 == length
+ if (isMissingNumAttrs && dataset.isStreaming) {
+ // this condition is checked for every column, but should be
cheap
+ throw new RuntimeException(
+ s"""
+ |VectorAssembler cannot dynamically determine the size of
vectors for streaming
+ |data. Consider applying VectorSizeHint to ${c} so that
this transformer can be
+ |used to transform streaming inputs.
+ """.stripMargin.replaceAll("\n", " "))
+ }
+ if (isMissingNumAttrs) {
+ val column = dataset.select(c).na.drop()
--- End diff --
* The var name `column` isn't good. `colDataset` is better.
* An optional optimization is one-pass scanning the dataset and counting
non-null rows for each "missing num attrs" column.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]