This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2cdbedeaebdf [SPARK-50267][ML] Improve `TargetEncoder.fit` with
DataFrame APIs
2cdbedeaebdf is described below
commit 2cdbedeaebdfdcea2fb7844cc3608d8c8341309b
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Nov 10 09:31:29 2024 +0900
[SPARK-50267][ML] Improve `TargetEncoder.fit` with DataFrame APIs
### What changes were proposed in this pull request?
Improve `TargetEncoder.fit` to be based on DataFrame APIs
### Why are the changes needed?
1, simplify the implementation;
2, with DataFrame APIs, it will benefit from the optimizations of Spark SQL
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48797 from zhengruifeng/target_encoder_fit.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../apache/spark/ml/feature/TargetEncoder.scala | 122 +++++++++------------
.../spark/ml/feature/TargetEncoderSuite.scala | 13 +--
2 files changed, 57 insertions(+), 78 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
index 9afb88afec93..0103282c269d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -189,82 +189,66 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0")
override val uid: String)
validateSchema(schema, fitting = true)
}
+ private def extractLabel(name: String, targetType: String): Column = {
+ val c = col(name).cast(DoubleType)
+ targetType match {
+ case TargetEncoder.TARGET_BINARY =>
+ when(c === 0 || c === 1, c)
+ .when(c.isNull || c.isNaN, c)
+ .otherwise(raise_error(
+ concat(lit("Labels for TARGET_BINARY must be {0, 1}, but got "),
c)))
+
+ case TargetEncoder.TARGET_CONTINUOUS => c
+ }
+ }
+
+ private def extractValue(name: String): Column = {
+ val c = col(name).cast(DoubleType)
+ when(c >= 0 && c === c.cast(IntegerType), c)
+ .when(c.isNull, lit(TargetEncoder.NULL_CATEGORY))
+ .when(c.isNaN, raise_error(lit("Values MUST NOT be NaN")))
+ .otherwise(raise_error(
+ concat(lit("Values MUST be non-negative integers, but got "), c)))
+ }
+
@Since("4.0.0")
override def fit(dataset: Dataset[_]): TargetEncoderModel = {
validateSchema(dataset.schema, fitting = true)
+ val numFeatures = inputFeatures.length
- // stats: Array[Map[category, (counter,stat)]]
- val stats = dataset
- .select((inputFeatures :+
$(labelCol)).map(col(_).cast(DoubleType)).toIndexedSeq: _*)
- .rdd.treeAggregate(
- Array.fill(inputFeatures.length) {
- Map.empty[Double, (Double, Double)]
- })(
-
- (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) {
- val label = row.getDouble(inputFeatures.length)
- if (!label.equals(Double.NaN)) {
- inputFeatures.indices.map {
- feature => {
- val category: Double = {
- if (row.isNullAt(feature)) TargetEncoder.NULL_CATEGORY //
null category
- else {
- val value = row.getDouble(feature)
- if (value < 0.0 || value != value.toInt) throw new
SparkException(
- s"Values from column ${inputFeatures(feature)} must be
indices, " +
- s"but got $value.")
- else value // non-null category
- }
- }
- val (class_count, class_stat) =
agg(feature).getOrElse(category, (0.0, 0.0))
- val (global_count, global_stat) =
- agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0,
0.0))
- $(targetType) match {
- case TargetEncoder.TARGET_BINARY => // counting
- if (label == 1.0) {
- // positive => increment both counters for current &
unseen categories
- agg(feature) +
- (category -> (1 + class_count, 1 + class_stat)) +
- (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1
+ global_stat))
- } else if (label == 0.0) {
- // negative => increment only global counter for current
& unseen categories
- agg(feature) +
- (category -> (1 + class_count, class_stat)) +
- (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count,
global_stat))
- } else throw new SparkException(
- s"Values from column ${getLabelCol} must be binary (0,1)
but got $label.")
- case TargetEncoder.TARGET_CONTINUOUS => // incremental mean
- // increment counter and iterate on mean for current &
unseen categories
- agg(feature) +
- (category -> (1 + class_count,
- class_stat + ((label - class_stat) / (1 +
class_count)))) +
- (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count,
- global_stat + ((label - global_stat) / (1 +
global_count))))
- }
- }
- }.toArray
- } else agg // ignore NaN-labeled observations
- } else agg, // ignore null-labeled observations
-
- (agg1, agg2) => inputFeatures.indices.map {
- feature => {
- val categories = agg1(feature).keySet ++ agg2(feature).keySet
- categories.map(category =>
- category -> {
- val (counter1, stat1) = agg1(feature).getOrElse(category,
(0.0, 0.0))
- val (counter2, stat2) = agg2(feature).getOrElse(category,
(0.0, 0.0))
- $(targetType) match {
- case TargetEncoder.TARGET_BINARY => (counter1 + counter2,
stat1 + stat2)
- case TargetEncoder.TARGET_CONTINUOUS => (counter1 + counter2,
- ((counter1 * stat1) + (counter2 * stat2)) / (counter1 +
counter2))
- }
- }).toMap
- }
- }.toArray)
+ // Append the unseen category, for global stats computation
+ val arrayCol = array(
+ (inputFeatures.map(v => extractValue(v)) :+
lit(TargetEncoder.UNSEEN_CATEGORY))
+ .toIndexedSeq: _*)
+
+ val checked = dataset
+ .select(extractLabel($(labelCol), $(targetType)).as("label"),
arrayCol.as("array"))
+ .where(!col("label").isNaN && !col("label").isNull)
+ .select(col("label"), posexplode(col("array")).as(Seq("index", "value")))
+ val statCol = $(targetType) match {
+ case TargetEncoder.TARGET_BINARY => count_if(col("label") === 1)
+ case TargetEncoder.TARGET_CONTINUOUS => avg(col("label"))
+ }
+ val aggregated = checked
+ .groupBy("index", "value")
+ .agg(count(lit(1)).cast(DoubleType).as("count"),
statCol.cast(DoubleType).as("stat"))
+ // stats: Array[Map[category, (counter,stat)]]
+ val stats = Array.fill(numFeatures)(collection.mutable.Map.empty[Double,
(Double, Double)])
+ aggregated.select("index", "value", "count", "stat").collect()
+ .foreach { case Row(index: Int, value: Double, count: Double, stat:
Double) =>
+ if (index < numFeatures) {
+ // Assign the per-category stats to the corresponding feature
+ stats(index).update(value, (count, stat))
+ } else {
+ // Assign the global stats to all features
+ assert(value == TargetEncoder.UNSEEN_CATEGORY)
+ stats.foreach { s => s.update(value, (count, stat)) }
+ }
+ }
- val model = new TargetEncoderModel(uid, stats).setParent(this)
+ val model = new TargetEncoderModel(uid, stats.map(_.toMap)).setParent(this)
copyValues(model)
}
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
index 869be94ff127..6bb3ce224a2e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
@@ -376,15 +376,13 @@ class TargetEncoderSuite extends MLTest with
DefaultReadWriteTest {
val df_noindex = spark
.createDataFrame(sc.parallelize(data_binary :+ data_noindex), schema)
- val ex = intercept[SparkException] {
+ val ex = intercept[SparkRuntimeException] {
val model = encoder.fit(df_noindex)
print(model.stats)
}
- assert(ex.isInstanceOf[SparkException])
assert(ex.getMessage.contains(
- "Values from column input3 must be indices, but got 5.1"))
-
+ "Values MUST be non-negative integers, but got 5.1"))
}
test("TargetEncoder - invalid label") {
@@ -407,7 +405,6 @@ class TargetEncoderSuite extends MLTest with
DefaultReadWriteTest {
model.stats.zip(expected_stats_continuous).foreach{
case (actual, expected) => assert(actual.equals(expected))
}
-
}
test("TargetEncoder - non-binary labels") {
@@ -423,15 +420,13 @@ class TargetEncoderSuite extends MLTest with
DefaultReadWriteTest {
val df_non_binary = spark
.createDataFrame(sc.parallelize(data_binary :+ data_non_binary), schema)
- val ex = intercept[SparkException] {
+ val ex = intercept[SparkRuntimeException] {
val model = encoder.fit(df_non_binary)
print(model.stats)
}
- assert(ex.isInstanceOf[SparkException])
assert(ex.getMessage.contains(
- "Values from column label must be binary (0,1) but got 2.0"))
-
+ "Labels for TARGET_BINARY must be {0, 1}, but got 2.0"))
}
test("TargetEncoder - features renamed") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]