Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20829#discussion_r177558064
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala ---
    @@ -18,56 +18,68 @@
     package org.apache.spark.ml.feature
     
     import org.apache.spark.{SparkException, SparkFunSuite}
    -import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
    +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
     import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
     import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.ml.util.DefaultReadWriteTest
     import org.apache.spark.mllib.util.MLlibTestSparkContext
    -import org.apache.spark.sql.Row
    +import org.apache.spark.sql.{Dataset, Row}
     import org.apache.spark.sql.functions.{col, udf}
     
     class VectorAssemblerSuite
       extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
     
       import testImplicits._
     
    +  @transient var dfWithNulls: Dataset[_] = _
    +
    +  override def beforeAll(): Unit = {
    +    super.beforeAll()
    +    dfWithNulls = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)](
    +      (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 7L, null),
    +      (2, 1, 0.0, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 6L, null),
    +      (3, 3, null, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 8L, null),
    +      (4, 4, null, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 9L, null))
    +      .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls")
    +  }
    +
       test("params") {
         ParamsSuite.checkParams(new VectorAssembler)
       }
     
       test("assemble") {
         import org.apache.spark.ml.feature.VectorAssembler.assemble
    -    assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
    -    assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
    +    assert(assemble(Array(1), true)(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
    +    assert(assemble(Array(1, 1), true)(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
         val dv = Vectors.dense(2.0, 0.0)
    -    assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
    +    assert(assemble(Array(1, 2, 1), true)(0.0, dv, 1.0) ===
    +      Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
         val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
    -    assert(assemble(0.0, dv, 1.0, sv) ===
    +    assert(assemble(Array(1, 2, 1, 2), true)(0.0, dv, 1.0, sv) ===
           Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
    -    for (v <- Seq(1, "a", null)) {
    -      intercept[SparkException](assemble(v))
    -      intercept[SparkException](assemble(1.0, v))
    +    for (v <- Seq(1, "a")) {
    +      intercept[SparkException](assemble(Array(1), true)(v))
    +      intercept[SparkException](assemble(Array(1, 1), true)(1.0, v))
         }
       }
     
       test("assemble should compress vectors") {
         import org.apache.spark.ml.feature.VectorAssembler.assemble
    -    val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
    +    val v1 = assemble(Array(1, 1, 1, 1), true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
         assert(v1.isInstanceOf[SparseVector])
    -    val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
    +    val sv = Vectors.sparse(1, Array(0), Array(4.0))
    +    val v2 = assemble(Array(1, 1, 1, 1), true)(1.0, 2.0, 3.0, sv)
         assert(v2.isInstanceOf[DenseVector])
       }
     
       test("VectorAssembler") {
    -    val df = Seq(
    -      (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), 
Array(3.0)), 10L)
    -    ).toDF("id", "x", "y", "name", "z", "n")
    +    val df = dfWithNulls.filter("id1 == 1").withColumn("id", col("id1"))
    --- End diff --
    
    nit: If this is for consolidation, I'm actually against this little change 
since it obscures what this test is doing and moves the input Row farther from 
the expected output row.
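    
    For reference, the inline version being removed here is just (roughly, as it appears in the old lines of this diff):
    
        // Single-row input defined inline, right next to the expected output it is checked against.
        val df = Seq(
          (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
        ).toDF("id", "x", "y", "name", "z", "n")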


---
