Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20146#discussion_r161039325
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---
    @@ -33,12 +33,38 @@ class StringIndexerSuite
     
       test("params") {
         ParamsSuite.checkParams(new StringIndexer)
    -    val model = new StringIndexerModel("indexer", Array("a", "b"))
    +    val model = new StringIndexerModel("indexer", Array(Array("a", "b")))
         val modelWithoutUid = new StringIndexerModel(Array("a", "b"))
         ParamsSuite.checkParams(model)
         ParamsSuite.checkParams(modelWithoutUid)
       }
     
    +  test("params: input/output columns") {
    +    val stringIndexerSingleCol = new StringIndexer()
    +      .setInputCol("in").setOutputCol("out")
    +    val inOutCols1 = stringIndexerSingleCol.getInOutCols()
    +    assert(inOutCols1._1 === Array("in"))
    +    assert(inOutCols1._2 === Array("out"))
    +
    +    val stringIndexerMultiCol = new StringIndexer()
    +      .setInputCols(Array("in1", "in2")).setOutputCols(Array("out1", "out2"))
    +    val inOutCols2 = stringIndexerMultiCol.getInOutCols()
    +    assert(inOutCols2._1 === Array("in1", "in2"))
    +    assert(inOutCols2._2 === Array("out1", "out2"))
    +
    +    intercept[IllegalArgumentException] {
    +      new StringIndexer().setInputCol("in").setOutputCols(Array("out1", "out2")).getInOutCols()
    --- End diff --
    
    It would be better to test this by calling `stringIndexer.fit` and checking that an exception is thrown.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to