Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20829#discussion_r175915257
--- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala ---
@@ -147,4 +149,72 @@ class VectorAssemblerSuite
.filter(vectorUDF($"features") > 1)
.count() == 1)
}
+
+ test("assemble should keep nulls") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ assert(assemble(Seq(1, 1), true)(1.0, null) === Vectors.dense(1.0, Double.NaN))
+ assert(assemble(Seq(1, 2), true)(1.0, null) === Vectors.dense(1.0, Double.NaN, Double.NaN))
+ assert(assemble(Seq(1), true)(null) === Vectors.dense(Double.NaN))
+ assert(assemble(Seq(2), true)(null) === Vectors.dense(Double.NaN, Double.NaN))
+ }
+
+ test("assemble should throw errors") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ intercept[SparkException](assemble(Seq(1, 1), false)(1.0, null) ===
+ Vectors.dense(1.0, Double.NaN))
+ intercept[SparkException](assemble(Seq(1, 2), false)(1.0, null) ===
+ Vectors.dense(1.0, Double.NaN, Double.NaN))
+ intercept[SparkException](assemble(Seq(1), false)(null) === Vectors.dense(Double.NaN))
+ intercept[SparkException](assemble(Seq(2), false)(null) ===
+ Vectors.dense(Double.NaN, Double.NaN))
+ }
+
+ test("get lengths function") {
+ val df = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long)](
+ (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 7L),
+ (2, 1, 0.0, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 6L),
+ (3, 3, null, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 8L),
+ (4, 4, null, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 9L)
+ ).toDF("id1", "id2", "x", "y", "name", "z", "n")
+ assert(VectorAssembler.getLengthsFromFirst(df, Seq("y")).exists(_ == "y" -> 2))
+ intercept[NullPointerException](VectorAssembler.getLengthsFromFirst(df.sort("id2"), Seq("y")))
+ intercept[NoSuchElementException](
+ VectorAssembler.getLengthsFromFirst(df.filter("id1 > 4"), Seq("y")))
+
+ assert(VectorAssembler.getLengths(
+ df.sort("id2"), Seq("y"), VectorAssembler.SKIP_INVALID).exists(_ == "y" -> 2))
+ intercept[NullPointerException](VectorAssembler.getLengths(
+ df.sort("id2"), Seq("y"), VectorAssembler.ERROR_INVALID))
+ intercept[RuntimeException](VectorAssembler.getLengths(
+ df.sort("id2"), Seq("y"), VectorAssembler.KEEP_INVALID))
+ }
+
+ test("Handle Invalid should behave properly") {
+ val df = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long)](
+ (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 7L),
+ (2, 1, 0.0, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 6L),
+ (3, 3, null, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 8L),
+ (4, 4, null, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 9L)
+ ).toDF("id1", "id2", "x", "y", "name", "z", "n")
+
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("x", "y", "z", "n"))
+ .setOutputCol("features")
+
+ // behavior when first row has information
+ assert(assembler.setHandleInvalid("skip").transform(df).count() == 1)
+ intercept[RuntimeException](assembler.setHandleInvalid("keep").transform(df).collect())
+ intercept[SparkException](assembler.setHandleInvalid("error").transform(df).collect())
+
+ // numeric column is all null
+ intercept[RuntimeException](
+ assembler.setHandleInvalid("keep").transform(df.filter("id1==3")).count() == 1)
+
+ // vector column is all null
--- End diff --
ditto
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]