This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2df08b1de737 [SPARK-50282][ML] Simplify `TargetEncoderModel.transform`
2df08b1de737 is described below
commit 2df08b1de73740293a3cbc18400823e8d8d3d15e
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Nov 11 09:08:56 2024 -0800
[SPARK-50282][ML] Simplify `TargetEncoderModel.transform`
### What changes were proposed in this pull request?
Simplify `TargetEncoderModel.transform`
### Why are the changes needed?
The existing implementation builds the lookup logic with an `if-else` chain of
many branches; instead, we can use `try_element_at` for this purpose
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48816 from zhengruifeng/te_transform.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../apache/spark/ml/feature/TargetEncoder.scala | 79 ++++++++--------------
1 file changed, 30 insertions(+), 49 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
index 0103282c269d..d0046e3f0c5b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -77,7 +77,7 @@ private[ml] trait TargetEncoderBase extends Params with
HasLabelCol
final def getSmoothing: Double = $(smoothing)
- private[feature] lazy val inputFeatures =
+ private[feature] def inputFeatures: Array[String] =
if (isSet(inputCol)) {
Array($(inputCol))
} else if (isSet(inputCols)) {
@@ -86,7 +86,7 @@ private[ml] trait TargetEncoderBase extends Params with
HasLabelCol
Array.empty[String]
}
- private[feature] lazy val outputFeatures =
+ private[feature] def outputFeatures: Array[String] =
if (isSet(outputCol)) {
Array($(outputCol))
} else if (isSet(outputCols)) {
@@ -234,7 +234,7 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0")
override val uid: String)
.groupBy("index", "value")
.agg(count(lit(1)).cast(DoubleType).as("count"),
statCol.cast(DoubleType).as("stat"))
- // stats: Array[Map[category, (counter,stat)]]
+ // stats: Array[Map[category, (count, stat)]]
val stats = Array.fill(numFeatures)(collection.mutable.Map.empty[Double,
(Double, Double)])
aggregated.select("index", "value", "count", "stat").collect()
.foreach { case Row(index: Int, value: Double, count: Double, stat:
Double) =>
@@ -347,54 +347,35 @@ class TargetEncoderModel private[ml] (
}
}
- // builds a column-to-column function from a map of encodings
- val apply_encodings: Map[Double, Double] => (Column => Column) =
- (mappings: Map[Double, Double]) => {
- (col: Column) => {
- val nullWhen = when(col.isNull,
- mappings.get(TargetEncoder.NULL_CATEGORY) match {
- case Some(code) => lit(code)
- case None => if ($(handleInvalid) == TargetEncoder.KEEP_INVALID)
{
- lit(mappings.get(TargetEncoder.UNSEEN_CATEGORY).get)
- } else raise_error(lit(
- s"Unseen null value in feature ${col.toString}. To handle
unseen values, " +
- s"set Param handleInvalid to
${TargetEncoder.KEEP_INVALID}."))
- })
- val ordered_mappings = (mappings -
TargetEncoder.NULL_CATEGORY).toList.sortWith {
- (a, b) =>
- (b._1 == TargetEncoder.UNSEEN_CATEGORY) ||
- ((a._1 != TargetEncoder.UNSEEN_CATEGORY) && (a._1 < b._1))
- }
- ordered_mappings
- .foldLeft(nullWhen)(
- (new_col: Column, mapping) => {
- val (original, encoded) = mapping
- if (original != TargetEncoder.UNSEEN_CATEGORY) {
- new_col.when(col === original, lit(encoded))
- } else { // unseen category
- new_col.otherwise(
- if ($(handleInvalid) == TargetEncoder.KEEP_INVALID)
lit(encoded)
- else raise_error(concat(
- lit("Unseen value "), col,
- lit(s" in feature ${col.toString}. To handle unseen
values, " +
- s"set Param handleInvalid to
${TargetEncoder.KEEP_INVALID}."))))
- }
- })
- }
- }
+ val newCols = inputFeatures.zip(outputFeatures).zip(encodings).map {
+ case ((featureIn, featureOut), mapping) =>
+ val unseenErrMsg = s"Unseen value %s in feature $featureIn. " +
+ s"To handle unseen values, set Param handleInvalid to
${TargetEncoder.KEEP_INVALID}."
+ val unseenErrCol = raise_error(printf(lit(unseenErrMsg),
col(featureIn).cast(StringType)))
- dataset.withColumns(
- inputFeatures.zip(outputFeatures).zip(encodings)
- .map {
- case ((featureIn, featureOut), mapping) =>
- featureOut ->
- apply_encodings(mapping)(col(featureIn))
- .as(featureOut, NominalAttribute.defaultAttr
- .withName(featureOut)
- .withNumValues(mapping.values.toSet.size)
-
.withValues(mapping.values.toSet.toArray.map(_.toString)).toMetadata())
- }.toMap)
+ val fillUnseenCol = $(handleInvalid) match {
+ case TargetEncoder.KEEP_INVALID =>
lit(mapping(TargetEncoder.UNSEEN_CATEGORY))
+ case _ => unseenErrCol
+ }
+ val fillNullCol = mapping.get(TargetEncoder.NULL_CATEGORY) match {
+ case Some(code) => lit(code)
+ case _ => fillUnseenCol
+ }
+ val filteredMapping = mapping.filter { case (k, _) =>
+ k != TargetEncoder.UNSEEN_CATEGORY && k !=
TargetEncoder.NULL_CATEGORY
+ }
+ val castedCol = col(featureIn).cast(DoubleType)
+ val targetCol = try_element_at(typedlit(filteredMapping), castedCol)
+ when(castedCol.isNull, fillNullCol)
+ .when(!targetCol.isNull, targetCol)
+ .otherwise(fillUnseenCol)
+ .as(featureOut, NominalAttribute.defaultAttr
+ .withName(featureOut)
+ .withNumValues(mapping.values.toSet.size)
+
.withValues(mapping.values.toSet.toArray.map(_.toString)).toMetadata())
+ }
+ dataset.withColumns(outputFeatures.toIndexedSeq, newCols.toIndexedSeq)
}
@Since("4.0.0")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]