Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/19621#discussion_r148172976
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---
@@ -130,21 +152,33 @@ class StringIndexer @Since("1.4.0") (
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ @Since("2.3.0")
+ def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setOutputCols(value: Array[String]): this.type = set(outputCols,
value)
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
- val values = dataset.na.drop(Array($(inputCol)))
- .select(col($(inputCol)).cast(StringType))
- .rdd.map(_.getString(0))
- val labels = $(stringOrderType) match {
- case StringIndexer.frequencyDesc =>
values.countByValue().toSeq.sortBy(-_._2)
- .map(_._1).toArray
- case StringIndexer.frequencyAsc =>
values.countByValue().toSeq.sortBy(_._2)
- .map(_._1).toArray
- case StringIndexer.alphabetDesc =>
values.distinct.collect.sortWith(_ > _)
- case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_
< _)
+
+ val labelsArray = for (inputCol <- getInOutCols._1) yield {
+ val values = dataset.na.drop(Array(inputCol))
+ .select(col(inputCol).cast(StringType))
+ .rdd.map(_.getString(0))
--- End diff --
This fetches the values for each input column sequentially, triggering a
separate job per column. Could we collect the values for all input columns
in a single pass over the dataset instead?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]