Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20829#discussion_r177560225
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala ---
    @@ -147,4 +159,88 @@ class VectorAssemblerSuite
           .filter(vectorUDF($"features") > 1)
           .count() == 1)
       }
    +
    +  test("assemble should keep nulls when keepInvalid is true") {
    +    import org.apache.spark.ml.feature.VectorAssembler.assemble
    +    assert(assemble(Array(1, 1), true)(1.0, null) === Vectors.dense(1.0, Double.NaN))
    +    assert(assemble(Array(1, 2), true)(1.0, null) === Vectors.dense(1.0, Double.NaN, Double.NaN))
    +    assert(assemble(Array(1), true)(null) === Vectors.dense(Double.NaN))
    +    assert(assemble(Array(2), true)(null) === Vectors.dense(Double.NaN, Double.NaN))
    +  }
    +
    +  test("assemble should throw errors when keepInvalid is false") {
    +    import org.apache.spark.ml.feature.VectorAssembler.assemble
    +    intercept[SparkException](assemble(Array(1, 1), false)(1.0, null))
    +    intercept[SparkException](assemble(Array(1, 2), false)(1.0, null))
    +    intercept[SparkException](assemble(Array(1), false)(null))
    +    intercept[SparkException](assemble(Array(2), false)(null))
    +  }
    +
    +  test("get lengths functions") {
    +    import org.apache.spark.ml.feature.VectorAssembler._
    +    val df = dfWithNulls
    +    assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2))
    +    assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y")))
    +      .getMessage.contains("VectorSizeHint"))
    +    assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 4"),
    +      Seq("y"))).getMessage.contains("VectorSizeHint"))
    +
    +    assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2))
    +    assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID))
    +      .getMessage.contains("VectorSizeHint"))
    +    assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID))
    +      .getMessage.contains("VectorSizeHint"))
    +  }
    +
    +  test("Handle Invalid should behave properly") {
    +    val assembler = new VectorAssembler()
    +      .setInputCols(Array("x", "y", "z", "n"))
    +      .setOutputCol("features")
    +
    +    def run_with_metadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
    +      val attributeY = new AttributeGroup("y", 2)
    +      val subAttributesOfZ = Array(NumericAttribute.defaultAttr, NumericAttribute.defaultAttr)
    +      val attributeZ = new AttributeGroup(
    +        "z",
    +        Array[Attribute](
    +          NumericAttribute.defaultAttr.withName("foo"),
    +          NumericAttribute.defaultAttr.withName("bar")))
    +      val dfWithMetadata = dfWithNulls.withColumn("y", col("y"), attributeY.toMetadata())
    +        .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter)
    +      val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata)
    +      output.collect()
    +      output
    +    }
    +    def run_with_first_row(mode: String): Dataset[_] = {
    --- End diff ---
    
    style: Put empty line between functions
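    
    For reference, a minimal sketch of the suggested fix (function bodies elided; names taken from the diff above):
    
        def run_with_metadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
          // ... attach AttributeGroup metadata, transform with the given mode, collect ...
        }
    
        def run_with_first_row(mode: String): Dataset[_] = {
          // ... transform without metadata so vector lengths come from the first row ...
        }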

