Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20146#discussion_r161039325
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---
@@ -33,12 +33,38 @@ class StringIndexerSuite
test("params") {
ParamsSuite.checkParams(new StringIndexer)
- val model = new StringIndexerModel("indexer", Array("a", "b"))
+ val model = new StringIndexerModel("indexer", Array(Array("a", "b")))
val modelWithoutUid = new StringIndexerModel(Array("a", "b"))
ParamsSuite.checkParams(model)
ParamsSuite.checkParams(modelWithoutUid)
}
+ test("params: input/output columns") {
+ val stringIndexerSingleCol = new StringIndexer()
+ .setInputCol("in").setOutputCol("out")
+ val inOutCols1 = stringIndexerSingleCol.getInOutCols()
+ assert(inOutCols1._1 === Array("in"))
+ assert(inOutCols1._2 === Array("out"))
+
+ val stringIndexerMultiCol = new StringIndexer()
+ .setInputCols(Array("in1", "in2")).setOutputCols(Array("out1",
"out2"))
+ val inOutCols2 = stringIndexerMultiCol.getInOutCols()
+ assert(inOutCols2._1 === Array("in1", "in2"))
+ assert(inOutCols2._2 === Array("out1", "out2"))
+
+ intercept[IllegalArgumentException] {
+ new StringIndexer().setInputCol("in").setOutputCols(Array("out1",
"out2")).getInOutCols()
--- End diff --
It would be better to call `stringIndexer.fit` here and check that the expected
exception is thrown.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]