This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 596098b5fe61 [SPARK-49223][ML] Simplify the StringIndexer.countByValue with builtin functions
596098b5fe61 is described below
commit 596098b5fe61d5f4987d0a77156b7724a1a697f7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Aug 22 16:21:53 2024 +0800
[SPARK-49223][ML] Simplify the StringIndexer.countByValue with builtin functions
### What changes were proposed in this pull request?
Simplify the StringIndexer.countByValue with builtin functions
### Why are the changes needed?
the StringIndexerAggregator is not necessary here
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47742 from zhengruifeng/sql_gouped_count.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../apache/spark/ml/feature/StringIndexer.scala | 72 +++++-----------------
1 file changed, 17 insertions(+), 55 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 6c10630e7bb8..72947dc17b8e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -27,8 +27,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row}
-import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
@@ -201,16 +200,23 @@ class StringIndexer @Since("1.4.0") (
private def countByValue(
dataset: Dataset[_],
inputCols: Array[String]): Array[OpenHashMap[String, Long]] = {
-
- val aggregator = new StringIndexerAggregator(inputCols.length)
- implicit val encoder = Encoders.kryo[Array[OpenHashMap[String, Long]]]
-
val selectedCols = getSelectedCols(dataset, inputCols.toImmutableArraySeq)
- dataset.select(selectedCols: _*)
- .toDF()
- .agg(aggregator.toColumn)
- .as[Array[OpenHashMap[String, Long]]]
- .collect()(0)
+ val results = Array.fill(selectedCols.size)(new OpenHashMap[String, Long]())
+ dataset.select(posexplode(array(selectedCols: _*)).as(Seq("index", "value")))
+ .where(col("value").isNotNull)
+ .groupBy("index", "value")
+ .agg(count(lit(1)).as("count"))
+ .groupBy("index")
+ .agg(collect_list(struct("value", "count")))
+ .collect()
+ .foreach { row =>
+ val index = row.getInt(0)
+ val result = results(index)
+ row.getSeq[Row](1).foreach { case Row(label: String, count: Long) =>
+ result.update(label, count)
+ }
+ }
+ results
}
private def sortByFreq(dataset: Dataset[_], ascending: Boolean): Array[Array[String]] = {
@@ -642,47 +648,3 @@ object IndexToString extends DefaultParamsReadable[IndexToString] {
@Since("1.6.0")
override def load(path: String): IndexToString = super.load(path)
}
-
-/**
- * A SQL `Aggregator` used by `StringIndexer` to count labels in string columns during fitting.
- */
-private class StringIndexerAggregator(numColumns: Int)
- extends Aggregator[Row, Array[OpenHashMap[String, Long]], Array[OpenHashMap[String, Long]]] {
-
- override def zero: Array[OpenHashMap[String, Long]] =
- Array.fill(numColumns)(new OpenHashMap[String, Long]())
-
- def reduce(
- array: Array[OpenHashMap[String, Long]],
- row: Row): Array[OpenHashMap[String, Long]] = {
- for (i <- 0 until numColumns) {
- val stringValue = row.getString(i)
- // We don't count for null values.
- if (stringValue != null) {
- array(i).changeValue(stringValue, 1L, _ + 1)
- }
- }
- array
- }
-
- def merge(
- array1: Array[OpenHashMap[String, Long]],
- array2: Array[OpenHashMap[String, Long]]): Array[OpenHashMap[String, Long]] = {
- for (i <- 0 until numColumns) {
- array2(i).foreach { case (key: String, count: Long) =>
- array1(i).changeValue(key, count, _ + count)
- }
- }
- array1
- }
-
- def finish(array: Array[OpenHashMap[String, Long]]): Array[OpenHashMap[String, Long]] = array
-
- override def bufferEncoder: Encoder[Array[OpenHashMap[String, Long]]] = {
- Encoders.kryo[Array[OpenHashMap[String, Long]]]
- }
-
- override def outputEncoder: Encoder[Array[OpenHashMap[String, Long]]] = {
- Encoders.kryo[Array[OpenHashMap[String, Long]]]
- }
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]