This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a523d66755c0 [SPARK-51963][ML] Simplify IndexToString.transform
a523d66755c0 is described below

commit a523d66755c09b42dbcac2d304363ffe7c715ed6
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu May 1 07:42:35 2025 +0800

    [SPARK-51963][ML] Simplify IndexToString.transform
    
    ### What changes were proposed in this pull request?
    Simplify `IndexToString.transform` by replacing the index-lookup UDF with built-in column functions.
    
    ### Why are the changes needed?
    The logic is simple enough that we should use built-in functions instead of a UDF.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #50767 from zhengruifeng/ml_sql_string_ind.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../org/apache/spark/ml/feature/StringIndexer.scala    | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 6518b0d9cf92..30b8c813188f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -28,7 +28,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset}
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.functions.{get => fget, printf => fprintf, _}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.VersionUtils.majorMinorVersion
@@ -586,17 +586,11 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") override val uid: String)
     } else {
       $(labels)
     }
-    val indexer = udf { index: Double =>
-      val idx = index.toInt
-      if (0 <= idx && idx < values.length) {
-        values(idx)
-      } else {
-        throw new SparkException(s"Unseen index: $index ??")
-      }
-    }
-    val outputColName = $(outputCol)
-    dataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
+
+    val idxCol = col($(inputCol)).cast(IntegerType)
+    val valCol = when(lit(0) <= idxCol && idxCol < lit(values.length), fget(lit(values), idxCol))
+      .otherwise(raise_error(fprintf(lit("Unseen index: %s ??"), idxCol.cast(StringType))))
+    dataset.withColumn($(outputCol), valCol)
   }
 
   @Since("1.5.0")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to