Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19621#discussion_r152252491
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---
    @@ -130,21 +160,49 @@ class StringIndexer @Since("1.4.0") (
       @Since("1.4.0")
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    +
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setOutputCols(value: Array[String]): this.type = set(outputCols, 
value)
    +
       @Since("2.0.0")
       override def fit(dataset: Dataset[_]): StringIndexerModel = {
         transformSchema(dataset.schema, logging = true)
    -    val values = dataset.na.drop(Array($(inputCol)))
    -      .select(col($(inputCol)).cast(StringType))
    -      .rdd.map(_.getString(0))
    -    val labels = $(stringOrderType) match {
    -      case StringIndexer.frequencyDesc => 
values.countByValue().toSeq.sortBy(-_._2)
    -        .map(_._1).toArray
    -      case StringIndexer.frequencyAsc => 
values.countByValue().toSeq.sortBy(_._2)
    -        .map(_._1).toArray
    -      case StringIndexer.alphabetDesc => 
values.distinct.collect.sortWith(_ > _)
    -      case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ 
< _)
    +
    +    val inputCols = getInOutCols._1
    +
    +    val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, 
Long]())
    +
    +    val countByValueArray = dataset.na.drop(inputCols)
    +      .select(inputCols.map(col(_).cast(StringType)): _*)
    +      .rdd.aggregate(zeroState)(
    +      (state: Array[OpenHashMap[String, Long]], row: Row) => {
    +        for (i <- 0 until inputCols.length) {
    +          state(i).changeValue(row.getString(i), 1L, _ + 1)
    +        }
    +        state
    +      },
    +      (state1: Array[OpenHashMap[String, Long]], state2: 
Array[OpenHashMap[String, Long]]) => {
    +        for (i <- 0 until inputCols.length) {
    +          state2(i).foreach { case (key: String, count: Long) =>
    +            state1(i).changeValue(key, count, _ + count)
    +          }
    +        }
    +        state1
    +      }
    +    )
    +    val labelsArray = countByValueArray.map { countByValue =>
    +      $(stringOrderType) match {
    +        case StringIndexer.frequencyDesc => 
countByValue.toSeq.sortBy(-_._2).map(_._1).toArray
    +        case StringIndexer.frequencyAsc => 
countByValue.toSeq.sortBy(_._2).map(_._1).toArray
    +        case StringIndexer.alphabetDesc => 
countByValue.toSeq.map(_._1).sortWith(_ > _).toArray
    +        case StringIndexer.alphabetAsc => 
countByValue.toSeq.map(_._1).sortWith(_ < _).toArray
    --- End diff --
    
    Yes, but will the aggregated counting bring noticeable overhead? I don't want the 
code to include too many `if...else` branches.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to