GitHub user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20829#discussion_r174993897
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
@@ -85,18 +120,34 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0")
override val uid: String)
} else {
// Otherwise, treat all attributes as numeric. If we cannot
get the number of attributes
// from metadata, check the first row.
- val numAttrs =
group.numAttributes.getOrElse(first.getAs[Vector](index).size)
- Array.tabulate(numAttrs)(i =>
NumericAttribute.defaultAttr.withName(c + "_" + i))
+ (0 until length).map { i =>
NumericAttribute.defaultAttr.withName(c + "_" + i) }
+ }
+ case DoubleType =>
+ val attribute = Attribute.fromStructField(field)
+ attribute match {
+ case UnresolvedAttribute =>
+ Seq(NumericAttribute.defaultAttr.withName(c))
+ case _ =>
+ Seq(attribute.withName(c))
}
+ case _ : NumericType | BooleanType =>
+ // If the input column type is a compatible scalar type, assume
numeric.
+ Seq(NumericAttribute.defaultAttr.withName(c))
case otherType =>
throw new SparkException(s"VectorAssembler does not support the
$otherType type")
}
}
- val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
-
+ val featureAttributes = featureAttributesMap.flatten[Attribute]
+ val lengths = featureAttributesMap.map(a => a.length)
+ val metadata = new AttributeGroup($(outputCol),
featureAttributes.toArray).toMetadata()
+ val (filteredDataset, keepInvalid) = $(handleInvalid) match {
+ case StringIndexer.SKIP_INVALID => (dataset.na.drop("any",
$(inputCols)), false)
--- End diff --
You can use `dataset.na.drop($(inputCols))` directly: `na.drop(cols)` already defaults to the `"any"` mode, so passing `"any"` explicitly is redundant.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]