srowen commented on a change in pull request #20146: [SPARK-11215][ML] Add 
multiple columns support to StringIndexer
URL: https://github.com/apache/spark/pull/20146#discussion_r247541378
 
 

 ##########
 File path: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
 ##########
 @@ -172,37 +274,78 @@ object StringIndexer extends 
DefaultParamsReadable[StringIndexer] {
 
   @Since("1.6.0")
   override def load(path: String): StringIndexer = super.load(path)
+
+  // Returns a function used to sort strings by frequency (ascending or 
descending).
+  // In case of equal frequency, it sorts strings by alphabet (ascending).
+  private[feature] def getSortFunc(
+      ascending: Boolean): ((String, Long), (String, Long)) => Boolean = {
+    if (ascending) {
+      { case ((strA: String, freqA: Long), (strB: String, freqB: Long)) =>
+        if (freqA == freqB) {
+         strA < strB
+        } else {
+         freqA < freqB
+        }
+      }
+    } else {
+      { case ((strA: String, freqA: Long), (strB: String, freqB: Long)) =>
+        if (freqA == freqB) {
+          strA  < strB
+        } else {
+          freqA > freqB
+        }
+      }
+    }
+  }
 }
 
 /**
  * Model fitted by [[StringIndexer]].
  *
- * @param labels  Ordered list of labels, corresponding to indices to be 
assigned.
+ * @param labelsArray Array of ordered list of labels, corresponding to 
indices to be assigned
+ *                    for each input column.
  *
- * @note During transformation, if the input column does not exist,
- * `StringIndexerModel.transform` would return the input dataset unmodified.
+ * @note During transformation, if any input column does not exist,
+ * `StringIndexerModel.transform` would skip the input column.
+ * If all input columns do not exist, it returns the input dataset unmodified.
  * This is a temporary fix for the case when target labels do not exist during 
prediction.
  */
 @Since("1.4.0")
 class StringIndexerModel (
     @Since("1.4.0") override val uid: String,
-    @Since("1.5.0") val labels: Array[String])
+    @Since("3.0.0") val labelsArray: Array[Array[String]])
   extends Model[StringIndexerModel] with StringIndexerBase with MLWritable {
 
   import StringIndexerModel._
 
   @Since("1.5.0")
-  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), 
labels)
-
-  private val labelToIndex: OpenHashMap[String, Double] = {
-    val n = labels.length
-    val map = new OpenHashMap[String, Double](n)
-    var i = 0
-    while (i < n) {
-      map.update(labels(i), i)
-      i += 1
+  def this(uid: String, labels: Array[String]) = this(uid, Array(labels))
+
+  @Since("1.5.0")
+  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), 
Array(labels))
+
+  @Since("3.0.0")
+  def this(labelsArray: Array[Array[String]]) = 
this(Identifiable.randomUID("strIdx"), labelsArray)
+
+  @deprecated("`labels` is deprecated and will be removed in 3.1.0. Use 
`labelsArray` " +
+    "instead.", "3.0.0")
+  @Since("1.5.0")
+  def labels: Array[String] = {
+    require(labelsArray.length == 1, "This StringIndexerModel is fitted by 
multi-columns, " +
 
 Review comment:
   Nit: the description needs some rewording. "This StringIndexerModel is fit 
on multiple columns. Call `labelsArray` instead."

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to