zhengruifeng commented on code in PR #48797:
URL: https://github.com/apache/spark/pull/48797#discussion_r1833389441
##########
mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala:
##########
@@ -189,82 +189,66 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0")
override val uid: String)
validateSchema(schema, fitting = true)
}
+ private def extractLabel(name: String, targetType: String): Column = {
+ val c = col(name).cast(DoubleType)
+ targetType match {
+ case TargetEncoder.TARGET_BINARY =>
+ when(c === 0 || c === 1, c)
+ .when(c.isNull || c.isNaN, c)
+ .otherwise(raise_error(
+ concat(lit("Labels for TARGET_BINARY must be {0, 1}, but got "),
c)))
+
+ case TargetEncoder.TARGET_CONTINUOUS => c
+ }
+ }
+
+ private def extractValue(name: String): Column = {
+ val c = col(name).cast(DoubleType)
+ when(c >= 0 && c === c.cast(IntegerType), c)
+ .when(c.isNull, lit(TargetEncoder.NULL_CATEGORY))
+ .when(c.isNaN, raise_error(lit("Values MUST NOT be NaN")))
Review Comment:
I suspect we should also treat `NaN` as a valid category for missing values,
following scikit-learn's
[doc](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.TargetEncoder.html):
> TargetEncoder considers missing values, such as np.nan or None, as another
category and encodes them like any other category. Categories that are not seen
during fit are encoded with the target mean, i.e. target_mean_.

That said, let's keep the existing behavior (rejecting `NaN` values) for now.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]