Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20829#discussion_r177559339
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala ---
@@ -147,4 +159,88 @@ class VectorAssemblerSuite
.filter(vectorUDF($"features") > 1)
.count() == 1)
}
+
+ test("assemble should keep nulls when keepInvalid is true") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ assert(assemble(Array(1, 1), true)(1.0, null) === Vectors.dense(1.0,
Double.NaN))
+ assert(assemble(Array(1, 2), true)(1.0, null) === Vectors.dense(1.0,
Double.NaN, Double.NaN))
+ assert(assemble(Array(1), true)(null) === Vectors.dense(Double.NaN))
+ assert(assemble(Array(2), true)(null) === Vectors.dense(Double.NaN,
Double.NaN))
+ }
+
+ test("assemble should throw errors when keepInvalid is false") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ intercept[SparkException](assemble(Array(1, 1), false)(1.0, null))
+ intercept[SparkException](assemble(Array(1, 2), false)(1.0, null))
+ intercept[SparkException](assemble(Array(1), false)(null))
+ intercept[SparkException](assemble(Array(2), false)(null))
+ }
+
+ test("get lengths functions") {
+ import org.apache.spark.ml.feature.VectorAssembler._
+ val df = dfWithNulls
+ assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2))
+
assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"),
Seq("y")))
+ .getMessage.contains("VectorSizeHint"))
+
assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 4"),
+ Seq("y"))).getMessage.contains("VectorSizeHint"))
+
+ assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ ==
"y" -> 2))
+ assert(intercept[NullPointerException](getLengths(df.sort("id2"),
Seq("y"), ERROR_INVALID))
+ .getMessage.contains("VectorSizeHint"))
+ assert(intercept[RuntimeException](getLengths(df.sort("id2"),
Seq("y"), KEEP_INVALID))
+ .getMessage.contains("VectorSizeHint"))
+ }
+
+ test("Handle Invalid should behave properly") {
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("x", "y", "z", "n"))
+ .setOutputCol("features")
+
+ def run_with_metadata(mode: String, additional_filter: String =
"true"): Dataset[_] = {
--- End diff --
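style: use camelCase for the method and parameter names here (e.g. `runWithMetadata`, `additionalFilter`) instead of snake_case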
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]