This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2df08b1de737 [SPARK-50282][ML] Simplify `TargetEncoderModel.transform`
2df08b1de737 is described below

commit 2df08b1de73740293a3cbc18400823e8d8d3d15e
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Nov 11 09:08:56 2024 -0800

    [SPARK-50282][ML] Simplify `TargetEncoderModel.transform`
    
    ### What changes were proposed in this pull request?
    Simplify `TargetEncoderModel.transform`
    
    ### Why are the changes needed?
    The existing implementation builds the lookup logic as an `if-else` chain with 
many branches; we can use `try_element_at` for this purpose instead.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48816 from zhengruifeng/te_transform.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../apache/spark/ml/feature/TargetEncoder.scala    | 79 ++++++++--------------
 1 file changed, 30 insertions(+), 49 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
index 0103282c269d..d0046e3f0c5b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -77,7 +77,7 @@ private[ml] trait TargetEncoderBase extends Params with 
HasLabelCol
 
   final def getSmoothing: Double = $(smoothing)
 
-  private[feature] lazy val inputFeatures =
+  private[feature] def inputFeatures: Array[String] =
     if (isSet(inputCol)) {
       Array($(inputCol))
     } else if (isSet(inputCols)) {
@@ -86,7 +86,7 @@ private[ml] trait TargetEncoderBase extends Params with 
HasLabelCol
       Array.empty[String]
     }
 
-  private[feature] lazy val outputFeatures =
+  private[feature] def outputFeatures: Array[String] =
     if (isSet(outputCol)) {
       Array($(outputCol))
     } else if (isSet(outputCols)) {
@@ -234,7 +234,7 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") 
override val uid: String)
       .groupBy("index", "value")
       .agg(count(lit(1)).cast(DoubleType).as("count"), 
statCol.cast(DoubleType).as("stat"))
 
-    // stats: Array[Map[category, (counter,stat)]]
+    // stats: Array[Map[category, (count, stat)]]
     val stats = Array.fill(numFeatures)(collection.mutable.Map.empty[Double, 
(Double, Double)])
     aggregated.select("index", "value", "count", "stat").collect()
       .foreach { case Row(index: Int, value: Double, count: Double, stat: 
Double) =>
@@ -347,54 +347,35 @@ class TargetEncoderModel private[ml] (
           }
       }
 
-    // builds a column-to-column function from a map of encodings
-    val apply_encodings: Map[Double, Double] => (Column => Column) =
-      (mappings: Map[Double, Double]) => {
-        (col: Column) => {
-          val nullWhen = when(col.isNull,
-            mappings.get(TargetEncoder.NULL_CATEGORY) match {
-              case Some(code) => lit(code)
-              case None => if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) 
{
-                lit(mappings.get(TargetEncoder.UNSEEN_CATEGORY).get)
-              } else raise_error(lit(
-                s"Unseen null value in feature ${col.toString}. To handle 
unseen values, " +
-                  s"set Param handleInvalid to 
${TargetEncoder.KEEP_INVALID}."))
-            })
-          val ordered_mappings = (mappings - 
TargetEncoder.NULL_CATEGORY).toList.sortWith {
-            (a, b) =>
-              (b._1 == TargetEncoder.UNSEEN_CATEGORY) ||
-                ((a._1 != TargetEncoder.UNSEEN_CATEGORY) && (a._1 < b._1))
-          }
-          ordered_mappings
-            .foldLeft(nullWhen)(
-              (new_col: Column, mapping) => {
-                val (original, encoded) = mapping
-                if (original != TargetEncoder.UNSEEN_CATEGORY) {
-                  new_col.when(col === original, lit(encoded))
-                } else { // unseen category
-                  new_col.otherwise(
-                    if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) 
lit(encoded)
-                    else raise_error(concat(
-                      lit("Unseen value "), col,
-                      lit(s" in feature ${col.toString}. To handle unseen 
values, " +
-                        s"set Param handleInvalid to 
${TargetEncoder.KEEP_INVALID}."))))
-                }
-              })
-        }
-      }
+    val newCols = inputFeatures.zip(outputFeatures).zip(encodings).map {
+      case ((featureIn, featureOut), mapping) =>
+        val unseenErrMsg = s"Unseen value %s in feature $featureIn. " +
+          s"To handle unseen values, set Param handleInvalid to 
${TargetEncoder.KEEP_INVALID}."
+        val unseenErrCol = raise_error(printf(lit(unseenErrMsg), 
col(featureIn).cast(StringType)))
 
-    dataset.withColumns(
-      inputFeatures.zip(outputFeatures).zip(encodings)
-        .map {
-          case ((featureIn, featureOut), mapping) =>
-            featureOut ->
-              apply_encodings(mapping)(col(featureIn))
-                .as(featureOut, NominalAttribute.defaultAttr
-                  .withName(featureOut)
-                  .withNumValues(mapping.values.toSet.size)
-                  
.withValues(mapping.values.toSet.toArray.map(_.toString)).toMetadata())
-        }.toMap)
+        val fillUnseenCol = $(handleInvalid) match {
+          case TargetEncoder.KEEP_INVALID => 
lit(mapping(TargetEncoder.UNSEEN_CATEGORY))
+          case _ => unseenErrCol
+        }
+        val fillNullCol = mapping.get(TargetEncoder.NULL_CATEGORY) match {
+          case Some(code) => lit(code)
+          case _ => fillUnseenCol
+        }
+        val filteredMapping = mapping.filter { case (k, _) =>
+          k != TargetEncoder.UNSEEN_CATEGORY && k != 
TargetEncoder.NULL_CATEGORY
+        }
 
+        val castedCol = col(featureIn).cast(DoubleType)
+        val targetCol = try_element_at(typedlit(filteredMapping), castedCol)
+        when(castedCol.isNull, fillNullCol)
+          .when(!targetCol.isNull, targetCol)
+          .otherwise(fillUnseenCol)
+          .as(featureOut, NominalAttribute.defaultAttr
+            .withName(featureOut)
+            .withNumValues(mapping.values.toSet.size)
+            
.withValues(mapping.values.toSet.toArray.map(_.toString)).toMetadata())
+    }
+    dataset.withColumns(outputFeatures.toIndexedSeq, newCols.toIndexedSeq)
   }
 
   @Since("4.0.0")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to