This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 596098b5fe61 [SPARK-49223][ML] Simplify the StringIndexer.countByValue with builtin functions
596098b5fe61 is described below

commit 596098b5fe61d5f4987d0a77156b7724a1a697f7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Aug 22 16:21:53 2024 +0800

    [SPARK-49223][ML] Simplify the StringIndexer.countByValue with builtin functions
    
    ### What changes were proposed in this pull request?
    Simplify `StringIndexer.countByValue` by replacing the custom `StringIndexerAggregator` with built-in SQL functions.
    
    ### Why are the changes needed?
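    The custom `StringIndexerAggregator` is not necessary here: the same per-column label counts can be computed with built-in functions (`posexplode`, `groupBy`/`count`, `collect_list`), which also removes the need for Kryo-encoded `OpenHashMap` aggregation buffers.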
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47742 from zhengruifeng/sql_gouped_count.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../apache/spark/ml/feature/StringIndexer.scala    | 72 +++++-----------------
 1 file changed, 17 insertions(+), 55 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 6c10630e7bb8..72947dc17b8e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -27,8 +27,7 @@ import org.apache.spark.ml.attribute.{Attribute, 
NominalAttribute}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, 
Encoder, Encoders, Row}
-import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, 
Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.ArrayImplicits._
@@ -201,16 +200,23 @@ class StringIndexer @Since("1.4.0") (
   private def countByValue(
       dataset: Dataset[_],
       inputCols: Array[String]): Array[OpenHashMap[String, Long]] = {
-
-    val aggregator = new StringIndexerAggregator(inputCols.length)
-    implicit val encoder = Encoders.kryo[Array[OpenHashMap[String, Long]]]
-
     val selectedCols = getSelectedCols(dataset, inputCols.toImmutableArraySeq)
-    dataset.select(selectedCols: _*)
-      .toDF()
-      .agg(aggregator.toColumn)
-      .as[Array[OpenHashMap[String, Long]]]
-      .collect()(0)
+    val results = Array.fill(selectedCols.size)(new OpenHashMap[String, 
Long]())
+    dataset.select(posexplode(array(selectedCols: _*)).as(Seq("index", 
"value")))
+      .where(col("value").isNotNull)
+      .groupBy("index", "value")
+      .agg(count(lit(1)).as("count"))
+      .groupBy("index")
+      .agg(collect_list(struct("value", "count")))
+      .collect()
+      .foreach { row =>
+        val index = row.getInt(0)
+        val result = results(index)
+        row.getSeq[Row](1).foreach { case Row(label: String, count: Long) =>
+          result.update(label, count)
+        }
+      }
+    results
   }
 
   private def sortByFreq(dataset: Dataset[_], ascending: Boolean): 
Array[Array[String]] = {
@@ -642,47 +648,3 @@ object IndexToString extends 
DefaultParamsReadable[IndexToString] {
   @Since("1.6.0")
   override def load(path: String): IndexToString = super.load(path)
 }
-
-/**
- * A SQL `Aggregator` used by `StringIndexer` to count labels in string 
columns during fitting.
- */
-private class StringIndexerAggregator(numColumns: Int)
-  extends Aggregator[Row, Array[OpenHashMap[String, Long]], 
Array[OpenHashMap[String, Long]]] {
-
-  override def zero: Array[OpenHashMap[String, Long]] =
-    Array.fill(numColumns)(new OpenHashMap[String, Long]())
-
-  def reduce(
-      array: Array[OpenHashMap[String, Long]],
-      row: Row): Array[OpenHashMap[String, Long]] = {
-    for (i <- 0 until numColumns) {
-      val stringValue = row.getString(i)
-      // We don't count for null values.
-      if (stringValue != null) {
-        array(i).changeValue(stringValue, 1L, _ + 1)
-      }
-    }
-    array
-  }
-
-  def merge(
-      array1: Array[OpenHashMap[String, Long]],
-      array2: Array[OpenHashMap[String, Long]]): Array[OpenHashMap[String, 
Long]] = {
-    for (i <- 0 until numColumns) {
-      array2(i).foreach { case (key: String, count: Long) =>
-        array1(i).changeValue(key, count, _ + count)
-      }
-    }
-    array1
-  }
-
-  def finish(array: Array[OpenHashMap[String, Long]]): 
Array[OpenHashMap[String, Long]] = array
-
-  override def bufferEncoder: Encoder[Array[OpenHashMap[String, Long]]] = {
-    Encoders.kryo[Array[OpenHashMap[String, Long]]]
-  }
-
-  override def outputEncoder: Encoder[Array[OpenHashMap[String, Long]]] = {
-    Encoders.kryo[Array[OpenHashMap[String, Long]]]
-  }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to