This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2cdbedeaebdf [SPARK-50267][ML] Improve `TargetEncoder.fit` with DataFrame APIs
2cdbedeaebdf is described below

commit 2cdbedeaebdfdcea2fb7844cc3608d8c8341309b
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Nov 10 09:31:29 2024 +0900

    [SPARK-50267][ML] Improve `TargetEncoder.fit` with DataFrame APIs
    
    ### What changes were proposed in this pull request?
    Reimplement `TargetEncoder.fit` with DataFrame APIs instead of an RDD `treeAggregate`.
    
    ### Why are the changes needed?
    1. Simplify the implementation;
    2. With DataFrame APIs, `fit` benefits from Spark SQL's query optimizations.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
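    Existing CI tests. `TargetEncoderSuite` was updated for the new exception type (`SparkRuntimeException`) and error messages.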
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48797 from zhengruifeng/target_encoder_fit.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../apache/spark/ml/feature/TargetEncoder.scala    | 122 +++++++++------------
 .../spark/ml/feature/TargetEncoderSuite.scala      |  13 +--
 2 files changed, 57 insertions(+), 78 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
index 9afb88afec93..0103282c269d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -189,82 +189,66 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") 
override val uid: String)
     validateSchema(schema, fitting = true)
   }
 
+  private def extractLabel(name: String, targetType: String): Column = {
+    val c = col(name).cast(DoubleType)
+    targetType match {
+      case TargetEncoder.TARGET_BINARY =>
+        when(c === 0 || c === 1, c)
+          .when(c.isNull || c.isNaN, c)
+          .otherwise(raise_error(
+            concat(lit("Labels for TARGET_BINARY must be {0, 1}, but got "), 
c)))
+
+      case TargetEncoder.TARGET_CONTINUOUS => c
+    }
+  }
+
+  private def extractValue(name: String): Column = {
+    val c = col(name).cast(DoubleType)
+    when(c >= 0 && c === c.cast(IntegerType), c)
+      .when(c.isNull, lit(TargetEncoder.NULL_CATEGORY))
+      .when(c.isNaN, raise_error(lit("Values MUST NOT be NaN")))
+      .otherwise(raise_error(
+        concat(lit("Values MUST be non-negative integers, but got "), c)))
+  }
+
   @Since("4.0.0")
   override def fit(dataset: Dataset[_]): TargetEncoderModel = {
     validateSchema(dataset.schema, fitting = true)
+    val numFeatures = inputFeatures.length
 
-    // stats: Array[Map[category, (counter,stat)]]
-    val stats = dataset
-      .select((inputFeatures :+ 
$(labelCol)).map(col(_).cast(DoubleType)).toIndexedSeq: _*)
-      .rdd.treeAggregate(
-        Array.fill(inputFeatures.length) {
-          Map.empty[Double, (Double, Double)]
-        })(
-
-        (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) {
-          val label = row.getDouble(inputFeatures.length)
-          if (!label.equals(Double.NaN)) {
-            inputFeatures.indices.map {
-              feature => {
-                val category: Double = {
-                  if (row.isNullAt(feature)) TargetEncoder.NULL_CATEGORY // 
null category
-                  else {
-                    val value = row.getDouble(feature)
-                    if (value < 0.0 || value != value.toInt) throw new 
SparkException(
-                      s"Values from column ${inputFeatures(feature)} must be 
indices, " +
-                        s"but got $value.")
-                    else value // non-null category
-                  }
-                }
-                val (class_count, class_stat) = 
agg(feature).getOrElse(category, (0.0, 0.0))
-                val (global_count, global_stat) =
-                  agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 
0.0))
-                $(targetType) match {
-                  case TargetEncoder.TARGET_BINARY => // counting
-                    if (label == 1.0) {
-                      // positive => increment both counters for current & 
unseen categories
-                      agg(feature) +
-                        (category -> (1 + class_count, 1 + class_stat)) +
-                        (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1 
+ global_stat))
-                    } else if (label == 0.0) {
-                      // negative => increment only global counter for current 
& unseen categories
-                      agg(feature) +
-                        (category -> (1 + class_count, class_stat)) +
-                        (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 
global_stat))
-                    } else throw new SparkException(
-                      s"Values from column ${getLabelCol} must be binary (0,1) 
but got $label.")
-                  case TargetEncoder.TARGET_CONTINUOUS => // incremental mean
-                    // increment counter and iterate on mean for current & 
unseen categories
-                    agg(feature) +
-                      (category -> (1 + class_count,
-                        class_stat + ((label - class_stat) / (1 + 
class_count)))) +
-                      (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count,
-                        global_stat + ((label - global_stat) / (1 + 
global_count))))
-                }
-              }
-            }.toArray
-          } else agg  // ignore NaN-labeled observations
-        } else agg,  // ignore null-labeled observations
-
-        (agg1, agg2) => inputFeatures.indices.map {
-          feature => {
-            val categories = agg1(feature).keySet ++ agg2(feature).keySet
-            categories.map(category =>
-              category -> {
-                val (counter1, stat1) = agg1(feature).getOrElse(category, 
(0.0, 0.0))
-                val (counter2, stat2) = agg2(feature).getOrElse(category, 
(0.0, 0.0))
-                $(targetType) match {
-                  case TargetEncoder.TARGET_BINARY => (counter1 + counter2, 
stat1 + stat2)
-                  case TargetEncoder.TARGET_CONTINUOUS => (counter1 + counter2,
-                    ((counter1 * stat1) + (counter2 * stat2)) / (counter1 + 
counter2))
-                }
-              }).toMap
-          }
-        }.toArray)
+    // Append the unseen category, for global stats computation
+    val arrayCol = array(
+      (inputFeatures.map(v => extractValue(v)) :+ 
lit(TargetEncoder.UNSEEN_CATEGORY))
+        .toIndexedSeq: _*)
+
+    val checked = dataset
+      .select(extractLabel($(labelCol), $(targetType)).as("label"), 
arrayCol.as("array"))
+      .where(!col("label").isNaN && !col("label").isNull)
+      .select(col("label"), posexplode(col("array")).as(Seq("index", "value")))
 
+    val statCol = $(targetType) match {
+      case TargetEncoder.TARGET_BINARY => count_if(col("label") === 1)
+      case TargetEncoder.TARGET_CONTINUOUS => avg(col("label"))
+    }
+    val aggregated = checked
+      .groupBy("index", "value")
+      .agg(count(lit(1)).cast(DoubleType).as("count"), 
statCol.cast(DoubleType).as("stat"))
 
+    // stats: Array[Map[category, (counter,stat)]]
+    val stats = Array.fill(numFeatures)(collection.mutable.Map.empty[Double, 
(Double, Double)])
+    aggregated.select("index", "value", "count", "stat").collect()
+      .foreach { case Row(index: Int, value: Double, count: Double, stat: 
Double) =>
+        if (index < numFeatures) {
+          // Assign the per-category stats to the corresponding feature
+          stats(index).update(value, (count, stat))
+        } else {
+          // Assign the global stats to all features
+          assert(value == TargetEncoder.UNSEEN_CATEGORY)
+          stats.foreach { s => s.update(value, (count, stat)) }
+        }
+      }
 
-    val model = new TargetEncoderModel(uid, stats).setParent(this)
+    val model = new TargetEncoderModel(uid, stats.map(_.toMap)).setParent(this)
     copyValues(model)
   }
 
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
index 869be94ff127..6bb3ce224a2e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
@@ -376,15 +376,13 @@ class TargetEncoderSuite extends MLTest with 
DefaultReadWriteTest {
     val df_noindex = spark
       .createDataFrame(sc.parallelize(data_binary :+ data_noindex), schema)
 
-    val ex = intercept[SparkException] {
+    val ex = intercept[SparkRuntimeException] {
       val model = encoder.fit(df_noindex)
       print(model.stats)
     }
 
-    assert(ex.isInstanceOf[SparkException])
     assert(ex.getMessage.contains(
-      "Values from column input3 must be indices, but got 5.1"))
-
+      "Values MUST be non-negative integers, but got 5.1"))
   }
 
   test("TargetEncoder - invalid label") {
@@ -407,7 +405,6 @@ class TargetEncoderSuite extends MLTest with 
DefaultReadWriteTest {
     model.stats.zip(expected_stats_continuous).foreach{
       case (actual, expected) => assert(actual.equals(expected))
     }
-
   }
 
   test("TargetEncoder - non-binary labels") {
@@ -423,15 +420,13 @@ class TargetEncoderSuite extends MLTest with 
DefaultReadWriteTest {
     val df_non_binary = spark
       .createDataFrame(sc.parallelize(data_binary :+ data_non_binary), schema)
 
-    val ex = intercept[SparkException] {
+    val ex = intercept[SparkRuntimeException] {
       val model = encoder.fit(df_non_binary)
       print(model.stats)
     }
 
-    assert(ex.isInstanceOf[SparkException])
     assert(ex.getMessage.contains(
-      "Values from column label must be binary (0,1) but got 2.0"))
-
+      "Labels for TARGET_BINARY must be {0, 1}, but got 2.0"))
   }
 
   test("TargetEncoder - features renamed") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to