This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new abe990e58eee [SPARK-37178][ML] Add Target Encoding to ml.feature
abe990e58eee is described below
commit abe990e58eeecfba177bd5ce1cce25872de074c9
Author: Enrique Rebollo <[email protected]>
AuthorDate: Wed Nov 6 15:48:01 2024 -0800
[SPARK-37178][ML] Add Target Encoding to ml.feature
### What changes were proposed in this pull request?
Adds support for target encoding of ml features.
Target Encoding maps a column of categorical indices into a numerical
feature derived from the target.
Leveraging the relationship between categorical variables and the target
variable, target encoding usually performs better than one-hot encoding, while
avoiding the need to add extra columns.
### Why are the changes needed?
Target Encoding is a well-known encoding technique for categorical features.
It is supported by most ML frameworks:
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.TargetEncoder.html
https://search.r-project.org/CRAN/refmans/dataPreparation/html/target_encode.html
### Does this PR introduce _any_ user-facing change?
The Spark API now includes 2 new classes in package org.apache.spark.ml.feature:
- TargetEncoder (estimator)
- TargetEncoderModel (transformer)
### How was this patch tested?
Scala => org.apache.spark.ml.feature.TargetEncoderSuite
Java => org.apache.spark.ml.feature.JavaTargetEncoderSuite
Python => python.pyspark.ml.tests.test_feature.FeatureTests (added 2 tests)
### Was this patch authored or co-authored using generative AI tooling?
No
### Some design notes
- binary and continuous target types (no multi-label yet)
- available in Scala, Java and Python APIs
- fitting implemented on RDD API (treeAggregate)
- transformation implemented on DataFrame API (no UDFs)
- categorical features must be indices (integers) in Double-typed columns
(as if StringIndexer were used before)
- missing values (null) in training are represented as category -1.0
- <b>Encodings structure</b>
- Map[String, Map[Double, Double]] => Map[ feature_name, Map[
original_category, encoded_category ] ]
- <b>Parameters</b>
- inputCol(s) / outputCol(s) / labelCol => as usual
- targetType
- binary => encodings calculated as in-category conditional
probability (counting)
- continuous => encodings calculated as in-category target mean
(incrementally)
- handleInvalid
- error => raises an error if trying to encode an unseen category
- keep => encodes an unseen category with the overall statistics
- smoothing => controls how in-category stats and overall stats are
weighted to calculate final encodings (to avoid overfitting)
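
As a rough sketch of the fitting and smoothing arithmetic described in these
notes (plain Python with made-up names, not the actual Spark implementation):

```python
# Hypothetical sketch of binary-target fitting: per-category (count, positives)
# plus overall totals, then blended with the smoothing weight n / (n + m).
def fit_binary(categories, labels):
    stats = {}          # category -> (count, positives)
    total = (0.0, 0.0)  # overall (count, positives)
    for cat, y in zip(categories, labels):
        cnt, pos = stats.get(cat, (0.0, 0.0))
        stats[cat] = (cnt + 1, pos + y)
        total = (total[0] + 1, total[1] + y)
    return stats, total

def encode(stats, total, m):
    overall = total[1] / total[0]
    out = {}
    for cat, (cnt, pos) in stats.items():
        w = cnt / (cnt + m)  # smoothing weight
        out[cat] = w * (pos / cnt) + (1 - w) * overall
    return out

stats, total = fit_binary([1.0, 1.0, 1.0, 2.0, 2.0, 3.0],
                          [0.0, 1.0, 0.0, 1.0, 0.0, 1.0])
enc = encode(stats, total, m=0.0)  # with m = 0: pure in-category estimates
```

With `m = 0` this reproduces plain conditional probabilities (1/3, 1/2, 1.0 for
categories 1, 2, 3); increasing `m` pulls rarely seen categories toward the
overall positive rate of 0.5.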
Closes #48347 from rebo16v/sparkml-target-encoding.
Lead-authored-by: Enrique Rebollo <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
docs/ml-features.md | 110 +++++
.../examples/ml/JavaTargetEncoderExample.java | 90 ++++
.../src/main/python/ml/target_encoder_example.py | 65 +++
.../spark/examples/ml/TargetEncoderExample.scala | 71 +++
.../apache/spark/ml/feature/TargetEncoder.scala | 474 ++++++++++++++++++
.../spark/ml/feature/JavaTargetEncoderSuite.java | 157 ++++++
.../spark/ml/feature/TargetEncoderSuite.scala | 529 +++++++++++++++++++++
python/docs/source/reference/pyspark.ml.rst | 2 +
python/pyspark/ml/feature.py | 301 ++++++++++++
python/pyspark/ml/tests/test_feature.py | 166 +++++++
10 files changed, 1965 insertions(+)
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 3dbb960dea03..418e94ad1ea1 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -855,6 +855,116 @@ for more details on the API.
</div>
+## TargetEncoder
+
+[Target
Encoding](https://www.researchgate.net/publication/220520258_A_Preprocessing_Scheme_for_High-Cardinality_Categorical_Attributes_in_Classification_and_Prediction_Problems)
is a data-preprocessing technique that transforms high-cardinality categorical
features into quasi-continuous scalar attributes suited for use in
regression-type models. This paradigm maps individual values of an independent
feature to a scalar, representing some estimate of the dependent attribute
(meaning cate [...]
+
+By leveraging the relationship between categorical features and the target
variable, Target Encoding usually performs better than One-Hot Encoding and
does not require a final binary vector encoding, decreasing the overall
dimensionality of the dataset.
+
+Users can specify input and output column names by setting `inputCol` and
`outputCol` for single-column use cases, or `inputCols` and `outputCols` for
multi-column use cases (both arrays must have the same size). These columns are
expected to contain categorical indices (non-negative integers), with missing
values (null) treated as a separate category. The data type must be a subclass
of 'NumericType'. For string type input data, it is common to encode
categorical features using [String [...]
+
+Users can specify the target column name by setting `labelCol`. This column is
expected to contain the ground-truth labels from which encodings will be
derived. Observations with missing labels (null) are not considered when
calculating estimates. The data type must be a subclass of 'NumericType'.
+
+`TargetEncoder` supports the `handleInvalid` parameter to choose how to handle
invalid input, meaning categories not seen during training, when encoding new
data. Available options include 'keep' (any invalid inputs are assigned to an
extra categorical index) and 'error' (throw an exception).
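+
A minimal sketch of what the two options mean at transform time (plain Python
with hypothetical names, not the actual Spark code):

```python
# 'keep' falls back to the dataset-level estimate for unseen categories;
# 'error' raises instead.
def lookup(encodings, overall, category, handle_invalid="error"):
    if category in encodings:
        return encodings[category]
    if handle_invalid == "keep":
        return overall
    raise ValueError(
        f"Unseen value {category}; set handleInvalid to 'keep' to encode it.")

encodings = {1.0: 0.333, 2.0: 0.5, 3.0: 1.0}
lookup(encodings, 0.5, 2.0)          # seen category: its learned encoding
lookup(encodings, 0.5, 9.0, "keep")  # unseen category: overall estimate
```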
+
+`TargetEncoder` supports the `targetType` parameter to choose the label type
when fitting data, affecting how estimates are calculated. Available options
include 'binary' and 'continuous'.
+
+When set to 'binary', the target attribute $Y$ is expected to be binary,
$Y\in\{ 0,1 \}$. The transformation maps individual values $X_{i}$ to the
conditional probability of $Y$ given that $X=X_{i}\;$: $\;\; S_{i}=P(Y\mid
X=X_{i})$. This approach is also known as bin-counting.
+
+When set to 'continuous', the target attribute $Y$ is expected to be
continuous, $Y\in\mathbb{R}$. The transformation maps individual values $X_{i}$
to the average of $Y$ given that $X=X_{i}\;$: $\;\; S_{i}=E[Y\mid X=X_{i}]$.
This approach is also known as mean-encoding.
+
+`TargetEncoder` supports the `smoothing` parameter to tune how in-category
stats and overall stats are blended. High-cardinality categorical features are
usually unevenly distributed across all possible values of $X$.
+Therefore, calculating encodings $S_{i}$ from in-class statistics alone makes
these estimates very unreliable, and rarely seen categories will very likely
cause overfitting in learning.
+
+Smoothing prevents this behaviour by weighting in-class estimates with overall
estimates according to the relative size of the particular class on the whole
dataset.
+
+$\;\;\; S_{i}=\lambda(n_{i})\, P(Y\mid X=X_{i})+(1-\lambda(n_{i}))\, P(Y)$ for
the binary case
+
+$\;\;\; S_{i}=\lambda(n_{i})\, E[Y\mid X=X_{i}]+(1-\lambda(n_{i}))\, E[Y]$ for
the continuous case
+
+where $\lambda(n_{i})$ is a monotonically increasing function of $n_{i}$,
bounded between 0 and 1.
+
+Usually $\lambda(n_{i})$ is implemented as the parametric function
$\lambda(n_{i})=\frac{n_{i}}{n_{i}+m}$, where $m$ is the smoothing factor,
represented by the `smoothing` parameter in `TargetEncoder`.
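+
The weight function above can be checked numerically (plain Python; the overall
mean of 2.0 below is just an illustrative number):

```python
# lambda(n) = n / (n + m): approaches 1 for frequent categories, and
# approaches 0 (so the overall estimate dominates) for rare ones.
def smoothed(n, in_class, overall, m):
    w = n / (n + m)
    return w * in_class + (1 - w) * overall

smoothed(1, 3.2, 2.0, m=0)    # no smoothing: pure in-class mean, 3.2
smoothed(1, 3.2, 2.0, m=1)    # a single observation is only half-trusted
smoothed(100, 3.2, 2.0, m=1)  # frequent categories barely move
```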
+
+**Examples**
+
+To illustrate `TargetEncoder`, let's assume we have the following
+DataFrame with columns `feature` and `target` (binary & continuous):
+
+~~~~
+ feature | target | target
+ | (bin) | (cont)
+ --------|--------|--------
+ 1 | 0 | 1.3
+ 1 | 1 | 2.5
+ 1 | 0 | 1.6
+ 2 | 1 | 1.8
+ 2 | 0 | 2.4
+ 3 | 1 | 3.2
+~~~~
+
+Applying `TargetEncoder` with 'binary' target type,
+`feature` as the input column, `target (bin)` as the label column
+and `encoded` as the output column, we are able to fit a model
+on the data to learn encodings and transform the data according
+to these mappings:
+
+~~~~
+ feature | target | encoded
+ | (bin) |
+ --------|--------|--------
+ 1 | 0 | 0.333
+ 1 | 1 | 0.333
+ 1 | 0 | 0.333
+ 2 | 1 | 0.5
+ 2 | 0 | 0.5
+ 3 | 1 | 1.0
+~~~~
+
+Applying `TargetEncoder` with 'continuous' target type,
+`feature` as the input column, `target (cont)` as the label column
+and `encoded` as the output column, we are able to fit a model
+on the data to learn encodings and transform the data according
+to these mappings:
+
+~~~~
+ feature | target | encoded
+ | (cont) |
+ --------|--------|--------
+ 1 | 1.3 | 1.8
+ 1 | 2.5 | 1.8
+ 1 | 1.6 | 1.8
+ 2 | 1.8 | 2.1
+ 2 | 2.4 | 2.1
+ 3 | 3.2 | 3.2
+~~~~
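+
The encoded values in both tables above (smoothing = 0) are simply per-category
means of the corresponding target column, which a few lines of plain Python can
confirm:

```python
from collections import defaultdict

feature     = [1, 1, 1, 2, 2, 3]
target_bin  = [0, 1, 0, 1, 0, 1]
target_cont = [1.3, 2.5, 1.6, 1.8, 2.4, 3.2]

def per_category_mean(xs, ys):
    # unsmoothed encoding: conditional probability (binary target)
    # or in-category mean (continuous target)
    sums, counts = defaultdict(float), defaultdict(int)
    for x, y in zip(xs, ys):
        sums[x] += y
        counts[x] += 1
    return {x: sums[x] / counts[x] for x in counts}

per_category_mean(feature, target_bin)   # {1: 0.333..., 2: 0.5, 3: 1.0}
per_category_mean(feature, target_cont)  # {1: 1.8..., 2: 2.1..., 3: 3.2}
```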
+
+<div class="codetabs">
+
+<div data-lang="python" markdown="1">
+
+Refer to the [TargetEncoder Python
docs](api/python/reference/api/pyspark.ml.feature.TargetEncoder.html) for more
details on the API.
+
+{% include_example python/ml/target_encoder_example.py %}
+</div>
+
+<div data-lang="scala" markdown="1">
+
+Refer to the [TargetEncoder Scala
docs](api/scala/org/apache/spark/ml/feature/TargetEncoder.html) for more
details on the API.
+
+{% include_example
scala/org/apache/spark/examples/ml/TargetEncoderExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [TargetEncoder Java
docs](api/java/org/apache/spark/ml/feature/TargetEncoder.html)
+for more details on the API.
+
+{% include_example
java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java %}
+</div>
+
+</div>
+
## VectorIndexer
`VectorIndexer` helps index categorical features in datasets of `Vector`s.
diff --git
a/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java
b/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java
new file mode 100644
index 000000000000..460f0d5a51e6
--- /dev/null
+++
b/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.sql.SparkSession;
+
+// $example on$
+import org.apache.spark.ml.feature.TargetEncoder;
+import org.apache.spark.ml.feature.TargetEncoderModel;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.Arrays;
+import java.util.List;
+// $example off$
+
+public class JavaTargetEncoderExample {
+ public static void main(String[] args) {
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaTargetEncoderExample")
+ .getOrCreate();
+
+ // Note: categorical features are usually first encoded with StringIndexer
+ // $example on$
+ List<Row> data = Arrays.asList(
+ RowFactory.create(0.0, 1.0, 0, 10.0),
+ RowFactory.create(1.0, 0.0, 1, 20.0),
+ RowFactory.create(2.0, 1.0, 0, 30.0),
+ RowFactory.create(0.0, 2.0, 1, 40.0),
+ RowFactory.create(0.0, 1.0, 0, 50.0),
+ RowFactory.create(2.0, 0.0, 1, 60.0)
+ );
+
+ StructType schema = new StructType(new StructField[]{
+ new StructField("categoryIndex1", DataTypes.DoubleType, false,
Metadata.empty()),
+ new StructField("categoryIndex2", DataTypes.DoubleType, false,
Metadata.empty()),
+ new StructField("binaryLabel", DataTypes.DoubleType, false,
Metadata.empty()),
+ new StructField("continuousLabel", DataTypes.DoubleType, false,
Metadata.empty())
+ });
+
+ Dataset<Row> df = spark.createDataFrame(data, schema);
+
+ // binary target
+ TargetEncoder bin_encoder = new TargetEncoder()
+ .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"})
+ .setOutputCols(new String[] {"categoryIndex1Target",
"categoryIndex2Target"})
+ .setLabelCol("binaryLabel")
+ .setTargetType("binary");
+
+ TargetEncoderModel bin_model = bin_encoder.fit(df);
+ Dataset<Row> bin_encoded = bin_model.transform(df);
+ bin_encoded.show();
+
+ // continuous target
+ TargetEncoder cont_encoder = new TargetEncoder()
+ .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"})
+ .setOutputCols(new String[] {"categoryIndex1Target",
"categoryIndex2Target"})
+ .setLabelCol("continuousLabel")
+ .setTargetType("continuous");
+
+ TargetEncoderModel cont_model = cont_encoder.fit(df);
+ Dataset<Row> cont_encoded = cont_model.transform(df);
+ cont_encoded.show();
+ // $example off$
+
+ spark.stop();
+ }
+}
+
diff --git a/examples/src/main/python/ml/target_encoder_example.py
b/examples/src/main/python/ml/target_encoder_example.py
new file mode 100644
index 000000000000..f6c1010de71f
--- /dev/null
+++ b/examples/src/main/python/ml/target_encoder_example.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# $example on$
+from pyspark.ml.feature import TargetEncoder
+
+# $example off$
+from pyspark.sql import SparkSession
+
+if __name__ == "__main__":
+ spark = SparkSession.builder.appName("TargetEncoderExample").getOrCreate()
+
+ # Note: categorical features are usually first encoded with StringIndexer
+ # $example on$
+ df = spark.createDataFrame(
+ [
+ (0.0, 1.0, 0, 10.0),
+ (1.0, 0.0, 1, 20.0),
+ (2.0, 1.0, 0, 30.0),
+ (0.0, 2.0, 1, 40.0),
+ (0.0, 1.0, 0, 50.0),
+ (2.0, 0.0, 1, 60.0),
+ ],
+ ["categoryIndex1", "categoryIndex2", "binaryLabel", "continuousLabel"],
+ )
+
+ # binary target
+ encoder = TargetEncoder(
+ inputCols=["categoryIndex1", "categoryIndex2"],
+ outputCols=["categoryIndex1Target", "categoryIndex2Target"],
+ labelCol="binaryLabel",
+ targetType="binary"
+ )
+ model = encoder.fit(df)
+ encoded = model.transform(df)
+ encoded.show()
+
+ # continuous target
+ encoder = TargetEncoder(
+ inputCols=["categoryIndex1", "categoryIndex2"],
+ outputCols=["categoryIndex1Target", "categoryIndex2Target"],
+ labelCol="continuousLabel",
+ targetType="continuous"
+ )
+
+ model = encoder.fit(df)
+ encoded = model.transform(df)
+ encoded.show()
+ # $example off$
+
+ spark.stop()
diff --git
a/examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala
b/examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala
new file mode 100644
index 000000000000..a03f903c86d0
--- /dev/null
+++
b/examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.feature.TargetEncoder
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+object TargetEncoderExample {
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession
+ .builder()
+ .appName("TargetEncoderExample")
+ .getOrCreate()
+
+ // Note: categorical features are usually first encoded with StringIndexer
+ // $example on$
+ val df = spark.createDataFrame(Seq(
+ (0.0, 1.0, 0, 10.0),
+ (1.0, 0.0, 1, 20.0),
+ (2.0, 1.0, 0, 30.0),
+ (0.0, 2.0, 1, 40.0),
+ (0.0, 1.0, 0, 50.0),
+ (2.0, 0.0, 1, 60.0)
+ )).toDF("categoryIndex1", "categoryIndex2",
+ "binaryLabel", "continuousLabel")
+
+ // binary target
+ val bin_encoder = new TargetEncoder()
+ .setInputCols(Array("categoryIndex1", "categoryIndex2"))
+ .setOutputCols(Array("categoryIndex1Target", "categoryIndex2Target"))
+ .setLabelCol("binaryLabel")
+      .setTargetType("binary")
+
+ val bin_model = bin_encoder.fit(df)
+ val bin_encoded = bin_model.transform(df)
+ bin_encoded.show()
+
+ // continuous target
+ val cont_encoder = new TargetEncoder()
+ .setInputCols(Array("categoryIndex1", "categoryIndex2"))
+ .setOutputCols(Array("categoryIndex1Target", "categoryIndex2Target"))
+ .setLabelCol("continuousLabel")
+      .setTargetType("continuous")
+
+ val cont_model = cont_encoder.fit(df)
+ val cont_encoded = cont_model.transform(df)
+ cont_encoded.show()
+ // $example off$
+
+ spark.stop()
+ }
+}
+// scalastyle:on println
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
new file mode 100644
index 000000000000..9afb88afec93
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -0,0 +1,474 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/** Private trait for params and common methods for TargetEncoder and
TargetEncoderModel */
+private[ml] trait TargetEncoderBase extends Params with HasLabelCol
+ with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols with
HasHandleInvalid {
+
+ /**
+ * Param for how to handle invalid data during transform().
+ * Options are 'keep' (invalid data presented as an extra categorical
feature) or
+ * 'error' (throw an error).
+ * Note that this Param is only used during transform; during fitting,
invalid data
+ * will result in an error.
+ * Default: "error"
+ * @group param
+ */
+ @Since("4.0.0")
+ override val handleInvalid: Param[String] = new Param[String](this,
"handleInvalid",
+ "How to handle invalid data during transform(). " +
+ "Options are 'keep' (invalid data presented as an extra categorical
feature) " +
+ "or 'error' (throw an error). Note that this Param is only used during
transform; " +
+ "during fitting, invalid data will result in an error.",
+ ParamValidators.inArray(TargetEncoder.supportedHandleInvalids))
+
+ setDefault(handleInvalid -> TargetEncoder.ERROR_INVALID)
+
+ @Since("4.0.0")
+ val targetType: Param[String] = new Param[String](this, "targetType",
+ "Type of label considered during fit(). " +
+ "Options are 'binary' and 'continuous'. When 'binary', estimates are
calculated as " +
+ "conditional probability of the target given each category. When
'continuous', " +
+ "estimates are calculated as the average of the target given each
category. " +
+ "Note that this Param is only used during fitting.",
+ ParamValidators.inArray(TargetEncoder.supportedTargetTypes))
+
+ setDefault(targetType -> TargetEncoder.TARGET_BINARY)
+
+ final def getTargetType: String = $(targetType)
+
+ @Since("4.0.0")
+ val smoothing: DoubleParam = new DoubleParam(this, "smoothing",
+ "Smoothing factor for encodings. Smoothing blends in-class estimates with
overall estimates " +
+ "according to the relative size of the particular class on the whole
dataset, reducing the " +
+ "risk of overfitting due to unreliable estimates",
+ ParamValidators.gtEq(0.0))
+
+ setDefault(smoothing -> 0.0)
+
+ final def getSmoothing: Double = $(smoothing)
+
+ private[feature] lazy val inputFeatures =
+ if (isSet(inputCol)) {
+ Array($(inputCol))
+ } else if (isSet(inputCols)) {
+ $(inputCols)
+ } else {
+ Array.empty[String]
+ }
+
+ private[feature] lazy val outputFeatures =
+ if (isSet(outputCol)) {
+ Array($(outputCol))
+ } else if (isSet(outputCols)) {
+ $(outputCols)
+ } else {
+ inputFeatures.map{field: String => s"${field}_indexed"}
+ }
+
+ private[feature] def validateSchema(schema: StructType, fitting: Boolean):
StructType = {
+
+ require(inputFeatures.length > 0,
+ s"At least one input column must be specified.")
+
+ require(inputFeatures.length == outputFeatures.length,
+ s"The number of input columns ${inputFeatures.length} must be the same
as the number of " +
+ s"output columns ${outputFeatures.length}.")
+
+ val features = if (fitting) inputFeatures :+ $(labelCol)
+ else inputFeatures
+
+ features.foreach {
+ feature => {
+ try {
+ val field = schema(feature)
+ if (!field.dataType.isInstanceOf[NumericType]) {
+ throw new SparkException(s"Data type for column ${feature} is
${field.dataType}" +
+ s", but a subclass of ${NumericType} is required.")
+ }
+ } catch {
+ case e: IllegalArgumentException =>
+ throw new SparkException(s"No column named ${feature} found on
dataset.")
+ }
+ }
+ }
+ schema
+ }
+
+}
+
+/**
+ * Target Encoding maps a column of categorical indices into a numerical
feature derived
+ * from the target.
+ *
+ * When `handleInvalid` is configured to 'keep', previously unseen values of a
feature
+ * are mapped to the dataset overall statistics.
+ *
+ * When 'targetType' is configured to 'binary', categories are encoded as the
conditional
+ * probability of the target given that category (bin counting).
+ * When 'targetType' is configured to 'continuous', categories are encoded as
the average
+ * of the target given that category (mean encoding)
+ *
+ * Parameter 'smoothing' controls how in-category stats and overall stats are
weighted.
+ *
+ * @note When encoding multi-column by using `inputCols` and `outputCols`
params, input/output cols
+ * come in pairs, specified by the order in the arrays, and each pair is
treated independently.
+ *
+ * @see `StringIndexer` for converting categorical values into category indices
+ */
+@Since("4.0.0")
+class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String)
+ extends Estimator[TargetEncoderModel] with TargetEncoderBase with
DefaultParamsWritable {
+
+ @Since("4.0.0")
+ def this() = this(Identifiable.randomUID("TargetEncoder"))
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setInputCols(values: Array[String]): this.type = set(inputCols, values)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setTargetType(value: String): this.type = set(targetType, value)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setSmoothing(value: Double): this.type = set(smoothing, value)
+
+ @Since("4.0.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateSchema(schema, fitting = true)
+ }
+
+ @Since("4.0.0")
+ override def fit(dataset: Dataset[_]): TargetEncoderModel = {
+ validateSchema(dataset.schema, fitting = true)
+
+ // stats: Array[Map[category, (counter,stat)]]
+ val stats = dataset
+ .select((inputFeatures :+
$(labelCol)).map(col(_).cast(DoubleType)).toIndexedSeq: _*)
+ .rdd.treeAggregate(
+ Array.fill(inputFeatures.length) {
+ Map.empty[Double, (Double, Double)]
+ })(
+
+ (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) {
+ val label = row.getDouble(inputFeatures.length)
+ if (!label.equals(Double.NaN)) {
+ inputFeatures.indices.map {
+ feature => {
+ val category: Double = {
+ if (row.isNullAt(feature)) TargetEncoder.NULL_CATEGORY //
null category
+ else {
+ val value = row.getDouble(feature)
+ if (value < 0.0 || value != value.toInt) throw new
SparkException(
+ s"Values from column ${inputFeatures(feature)} must be
indices, " +
+ s"but got $value.")
+ else value // non-null category
+ }
+ }
+ val (class_count, class_stat) =
agg(feature).getOrElse(category, (0.0, 0.0))
+ val (global_count, global_stat) =
+ agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0,
0.0))
+ $(targetType) match {
+ case TargetEncoder.TARGET_BINARY => // counting
+ if (label == 1.0) {
+ // positive => increment both counters for current &
unseen categories
+ agg(feature) +
+ (category -> (1 + class_count, 1 + class_stat)) +
+ (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1
+ global_stat))
+ } else if (label == 0.0) {
+ // negative => increment only global counter for current
& unseen categories
+ agg(feature) +
+ (category -> (1 + class_count, class_stat)) +
+ (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count,
global_stat))
+ } else throw new SparkException(
+ s"Values from column ${getLabelCol} must be binary (0,1)
but got $label.")
+ case TargetEncoder.TARGET_CONTINUOUS => // incremental mean
+ // increment counter and iterate on mean for current &
unseen categories
+ agg(feature) +
+ (category -> (1 + class_count,
+ class_stat + ((label - class_stat) / (1 +
class_count)))) +
+ (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count,
+ global_stat + ((label - global_stat) / (1 +
global_count))))
+ }
+ }
+ }.toArray
+ } else agg // ignore NaN-labeled observations
+ } else agg, // ignore null-labeled observations
+
+ (agg1, agg2) => inputFeatures.indices.map {
+ feature => {
+ val categories = agg1(feature).keySet ++ agg2(feature).keySet
+ categories.map(category =>
+ category -> {
+ val (counter1, stat1) = agg1(feature).getOrElse(category,
(0.0, 0.0))
+ val (counter2, stat2) = agg2(feature).getOrElse(category,
(0.0, 0.0))
+ $(targetType) match {
+ case TargetEncoder.TARGET_BINARY => (counter1 + counter2,
stat1 + stat2)
+ case TargetEncoder.TARGET_CONTINUOUS => (counter1 + counter2,
+ ((counter1 * stat1) + (counter2 * stat2)) / (counter1 +
counter2))
+ }
+ }).toMap
+ }
+ }.toArray)
+
+
+
+ val model = new TargetEncoderModel(uid, stats).setParent(this)
+ copyValues(model)
+ }
+
+ @Since("4.0.0")
+ override def copy(extra: ParamMap): TargetEncoder = defaultCopy(extra)
+}
+
+@Since("4.0.0")
+object TargetEncoder extends DefaultParamsReadable[TargetEncoder] {
+
+ // handleInvalid parameter values
+ private[feature] val KEEP_INVALID: String = "keep"
+ private[feature] val ERROR_INVALID: String = "error"
+ private[feature] val supportedHandleInvalids: Array[String] =
Array(KEEP_INVALID, ERROR_INVALID)
+
+ // targetType parameter values
+ private[feature] val TARGET_BINARY: String = "binary"
+ private[feature] val TARGET_CONTINUOUS: String = "continuous"
+ private[feature] val supportedTargetTypes: Array[String] =
Array(TARGET_BINARY, TARGET_CONTINUOUS)
+
+ private[feature] val UNSEEN_CATEGORY: Double = Int.MaxValue
+ private[feature] val NULL_CATEGORY: Double = -1
+
+ @Since("4.0.0")
+ override def load(path: String): TargetEncoder = super.load(path)
+}
+
+/**
+ * @param stats Array of statistics for each input feature.
+ * Array( Map( category, (counter, stat) ) )
+ */
+@Since("4.0.0")
+class TargetEncoderModel private[ml] (
+ @Since("4.0.0") override val uid: String,
+ @Since("4.0.0") val stats: Array[Map[Double, (Double,
Double)]])
+ extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable {
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setInputCols(values: Array[String]): this.type = set(inputCols, values)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+ /** @group setParam */
+ @Since("4.0.0")
+ def setSmoothing(value: Double): this.type = set(smoothing, value)
+
+ @Since("4.0.0")
+ override def transformSchema(schema: StructType): StructType = {
+ if (outputFeatures.length == stats.length) {
+ outputFeatures.filter(_ != null)
+ .foldLeft(validateSchema(schema, fitting = false)) {
+ case (newSchema, outputField) =>
+ newSchema.add(StructField(outputField, DoubleType, nullable =
false))
+ }
+ } else throw new SparkException("The number of features does not match the
number of " +
+ s"encodings in the model (${stats.length}). " +
+ s"Found ${outputFeatures.length} features)")
+ }
+
+ @Since("4.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema)
+
+ // encodings: Array[Map[category, encoding]]
+ val encodings: Array[Map[Double, Double]] =
+ stats.map {
+ stat =>
+ val (global_count, global_stat) =
stat.get(TargetEncoder.UNSEEN_CATEGORY).get
+ stat.map {
+ case (cat, (class_count, class_stat)) => cat -> {
+ val weight = class_count / (class_count + $(smoothing)) //
smoothing weight
+ $(targetType) match {
+ case TargetEncoder.TARGET_BINARY =>
+ // calculate conditional probabilities and blend
+ weight * (class_stat/ class_count) + (1 - weight) *
(global_stat / global_count)
+ case TargetEncoder.TARGET_CONTINUOUS =>
+ // blend means
+ weight * class_stat + (1 - weight) * global_stat
+ }
+ }
+ }
+ }
+
+ // builds a column-to-column function from a map of encodings
+ val apply_encodings: Map[Double, Double] => (Column => Column) =
+ (mappings: Map[Double, Double]) => {
+ (col: Column) => {
+        val nullWhen = when(col.isNull,
+          mappings.get(TargetEncoder.NULL_CATEGORY) match {
+            case Some(code) => lit(code)
+            case None =>
+              if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) {
+                lit(mappings(TargetEncoder.UNSEEN_CATEGORY))
+              } else raise_error(lit(
+                s"Unseen null value in feature ${col.toString}. To handle unseen values, " +
+                s"set Param handleInvalid to ${TargetEncoder.KEEP_INVALID}."))
+          })
+        val ordered_mappings = (mappings - TargetEncoder.NULL_CATEGORY).toList.sortWith {
+ (a, b) =>
+ (b._1 == TargetEncoder.UNSEEN_CATEGORY) ||
+ ((a._1 != TargetEncoder.UNSEEN_CATEGORY) && (a._1 < b._1))
+ }
+ ordered_mappings
+ .foldLeft(nullWhen)(
+ (new_col: Column, mapping) => {
+ val (original, encoded) = mapping
+ if (original != TargetEncoder.UNSEEN_CATEGORY) {
+ new_col.when(col === original, lit(encoded))
+ } else { // unseen category
+ new_col.otherwise(
+                  if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) lit(encoded)
+                  else raise_error(concat(
+                    lit("Unseen value "), col,
+                    lit(s" in feature ${col.toString}. To handle unseen values, " +
+                      s"set Param handleInvalid to ${TargetEncoder.KEEP_INVALID}.")))
+ }
+ })
+ }
+ }
+
+ dataset.withColumns(
+ inputFeatures.zip(outputFeatures).zip(encodings)
+ .map {
+ case ((featureIn, featureOut), mapping) =>
+ featureOut ->
+ apply_encodings(mapping)(col(featureIn))
+              .as(featureOut, NominalAttribute.defaultAttr
+                .withName(featureOut)
+                .withNumValues(mapping.values.toSet.size)
+                .withValues(mapping.values.toSet.toArray.map(_.toString)).toMetadata())
+ }.toMap)
+
+ }
+
+ @Since("4.0.0")
+ override def copy(extra: ParamMap): TargetEncoderModel = {
+ val copied = new TargetEncoderModel(uid, stats)
+ copyValues(copied, extra).setParent(parent)
+ }
+
+ @Since("4.0.0")
+  override def write: MLWriter = new TargetEncoderModel.TargetEncoderModelWriter(this)
+
+ @Since("4.0.0")
+ override def toString: String = {
+ s"TargetEncoderModel: uid=$uid, " +
+ s"handleInvalid=${$(handleInvalid)}, targetType=${$(targetType)}, " +
+      s"numInputCols=${inputFeatures.length}, numOutputCols=${outputFeatures.length}, " +
+ s"smoothing=${$(smoothing)}"
+ }
+
+}
+
+@Since("4.0.0")
+object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
+
+ private[TargetEncoderModel]
+  class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter {
+
+ private case class Data(stats: Array[Map[Double, (Double, Double)]])
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
+ val data = Data(instance.stats)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+ }
+ }
+
+ private class TargetEncoderModelReader extends MLReader[TargetEncoderModel] {
+
+ private val className = classOf[TargetEncoderModel].getName
+
+ override def load(path: String): TargetEncoderModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sparkSession.read.parquet(dataPath)
+        .select("stats")
+ .head()
+ val stats = data.getAs[Array[Map[Double, (Double, Double)]]](0)
+ val model = new TargetEncoderModel(metadata.uid, stats)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+
+ @Since("4.0.0")
+  override def read: MLReader[TargetEncoderModel] = new TargetEncoderModelReader
+
+ @Since("4.0.0")
+ override def load(path: String): TargetEncoderModel = super.load(path)
+}
+
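The blend computed per category in `transform` above can be sketched in plain Python (an illustrative sketch, not part of the patch; the function name and argument names are hypothetical). The weight is `w = class_count / (class_count + smoothing)`; for binary targets the per-category positive rate is blended with the global rate, and for continuous targets the per-category mean is blended with the global mean:

```python
def blend(class_count, class_stat, global_count, global_stat,
          smoothing=0.0, target_type="binary"):
    # w -> 1 as the category count grows or smoothing -> 0
    w = class_count / (class_count + smoothing)
    if target_type == "binary":
        # class_stat / global_stat hold positive-label counts here
        return w * (class_stat / class_count) + (1 - w) * (global_stat / global_count)
    # continuous: class_stat / global_stat already hold means
    return w * class_stat + (1 - w) * global_stat
```

With `smoothing=0.0` the weight is 1 and every category maps to its own statistic; the expected values in the test suites below (e.g. `(3.0/4)*(1.0/3) + (1 - 3.0/4)*(4.0/9)` for smoothing 1) follow directly from this formula.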
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java
new file mode 100644
index 000000000000..c488cc0dfca1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import org.apache.spark.SharedSparkSession;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.spark.sql.types.DataTypes.*;
+
+public class JavaTargetEncoderSuite extends SharedSparkSession {
+
+ @Test
+ public void testTargetEncoderBinary() {
+
+ // checkstyle.off: LineLength
+ List<Row> data = Arrays.asList(
+ RowFactory.create((short) 0, 3, 5.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3,
+ (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9),
+ (1 - 5.0 / 6) * (4.0 / 9),
+ (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)),
+ RowFactory.create((short) 1, 4, 5.0, 1.0, 2.0 / 3, 1.0, 1.0 / 3,
+ (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9),
+ (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9),
+ (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)),
+ RowFactory.create((short) 2, 3, 5.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3,
+ (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9),
+ (1 - 5.0 / 6) * (4.0 / 9),
+ (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)),
+ RowFactory.create((short) 0, 4, 6.0, 1.0, 1.0 / 3, 1.0, 2.0 / 3,
+ (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9),
+ (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9),
+ (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)),
+ RowFactory.create((short) 1, 3, 6.0, 0.0, 2.0 / 3, 0.0, 2.0 / 3,
+ (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9),
+ (1 - 5.0 / 6) * (4.0 / 9),
+ (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)),
+ RowFactory.create((short) 2, 4, 6.0, 1.0, 1.0 / 3, 1.0, 2.0 / 3,
+ (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9),
+ (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9),
+ (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)),
+ RowFactory.create((short) 0, 3, 7.0, 0.0, 1.0 / 3, 0.0, 0.0,
+ (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9),
+ (1 - 5.0 / 6) * (4.0 / 9), (1 - 1.0 / 2) * (4.0 / 9)),
+ RowFactory.create((short) 1, 4, 8.0, 1.0, 2.0 / 3, 1.0, 1.0,
+ (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9),
+ (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9),
+ (1.0 / 2) + (1 - 1.0 / 2) * (4.0 / 9)),
+ RowFactory.create((short) 2, 3, null, 0.0, 1.0 / 3, 0.0, 0.0,
+ (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9),
+ (1 - 5.0 / 6) * (4.0 / 9),
+ (1 - 1.0 / 2) * (4.0 / 9)));
+    // checkstyle.on: LineLength
+ StructType schema = createStructType(new StructField[]{
+ createStructField("input1", ShortType, true),
+ createStructField("input2", IntegerType, true),
+ createStructField("input3", DoubleType, true),
+ createStructField("label", DoubleType, false),
+ createStructField("expected1", DoubleType, false),
+ createStructField("expected2", DoubleType, false),
+ createStructField("expected3", DoubleType, false),
+ createStructField("smoothing1", DoubleType, false),
+ createStructField("smoothing2", DoubleType, false),
+ createStructField("smoothing3", DoubleType, false)
+ });
+
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
+
+ TargetEncoder encoder = new TargetEncoder()
+ .setInputCols(new String[]{"input1", "input2", "input3"})
+ .setOutputCols(new String[]{"output1", "output2", "output3"})
+ .setTargetType("binary");
+ TargetEncoderModel model = encoder.fit(dataset);
+
+ Dataset<Row> output = model.transform(dataset);
+ Assertions.assertEquals(
+ output.select("output1", "output2", "output3").collectAsList(),
+ output.select("expected1", "expected2", "expected3").collectAsList());
+
+ Dataset<Row> output_smoothing = model.setSmoothing(1.0).transform(dataset);
+ Assertions.assertEquals(
+ output_smoothing.select("output1", "output2", "output3").collectAsList(),
+      output_smoothing.select("smoothing1", "smoothing2", "smoothing3").collectAsList());
+
+ }
+
+ @Test
+ public void testTargetEncoderContinuous() {
+
+    List<Row> data = Arrays.asList(
+      RowFactory.create((short) 0, 3, 5.0, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5),
+      RowFactory.create((short) 1, 4, 5.0, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5),
+      RowFactory.create((short) 2, 3, 5.0, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5),
+      RowFactory.create((short) 0, 4, 6.0, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0),
+      RowFactory.create((short) 1, 3, 6.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0),
+      RowFactory.create((short) 2, 4, 6.0, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0),
+      RowFactory.create((short) 0, 3, 7.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0),
+      RowFactory.create((short) 1, 4, 8.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0),
+      RowFactory.create((short) 2, 3, null, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0));
+
+ StructType schema = createStructType(new StructField[]{
+ createStructField("input1", ShortType, true),
+ createStructField("input2", IntegerType, true),
+ createStructField("input3", DoubleType, true),
+ createStructField("label", DoubleType, false),
+ createStructField("expected1", DoubleType, false),
+ createStructField("expected2", DoubleType, false),
+ createStructField("expected3", DoubleType, false),
+ createStructField("smoothing1", DoubleType, false),
+ createStructField("smoothing2", DoubleType, false),
+ createStructField("smoothing3", DoubleType, false)
+ });
+
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
+
+ TargetEncoder encoder = new TargetEncoder()
+ .setInputCols(new String[]{"input1", "input2", "input3"})
+ .setOutputCols(new String[]{"output1", "output2", "output3"})
+ .setTargetType("continuous");
+ TargetEncoderModel model = encoder.fit(dataset);
+
+ Dataset<Row> output = model.transform(dataset);
+ Assertions.assertEquals(
+ output.select("output1", "output2", "output3").collectAsList(),
+ output.select("expected1", "expected2", "expected3").collectAsList());
+
+ Dataset<Row> output_smoothing = model.setSmoothing(1.0).transform(dataset);
+ Assertions.assertEquals(
+ output_smoothing.select("output1", "output2", "output3").collectAsList(),
+      output_smoothing.select("smoothing1", "smoothing2", "smoothing3").collectAsList());
+
+ }
+
+}
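The transform-time lookup that the model builds as a CASE WHEN chain (null category first, then each seen category, then the unseen fallback) can be mirrored in a pure-Python sketch (illustrative only, not part of the patch; `NULL_CATEGORY`/`UNSEEN_CATEGORY` stand in for the model's sentinel keys, and the real sentinels are Double values, not strings):

```python
NULL_CATEGORY = "null"      # sentinel for null inputs
UNSEEN_CATEGORY = "unseen"  # sentinel holding the global statistics

def encode_value(value, mappings, handle_invalid="error"):
    if value is None:
        if NULL_CATEGORY in mappings:   # null was seen during fit
            return mappings[NULL_CATEGORY]
        key = value                      # null never seen: treat as unseen
    elif value in mappings:
        return mappings[value]           # known category
    else:
        key = value                      # unseen category
    if handle_invalid == "keep":
        return mappings[UNSEEN_CATEGORY]  # fall back to the global encoding
    raise ValueError(f"Unseen value {key}. Set Param handleInvalid to keep.")
```

This matches the behavior exercised by the "unseen value - keep/error" and "null category" tests below: seen values map directly, and unseen or unexpected-null values either take the global encoding or raise, depending on handleInvalid.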
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
new file mode 100644
index 000000000000..869be94ff127
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
@@ -0,0 +1,529 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
+
+ import testImplicits._
+
+ @transient var data_binary: Seq[Row] = _
+ @transient var data_continuous: Seq[Row] = _
+ @transient var schema: StructType = _
+  @transient var expected_stats_binary: Array[Map[Double, (Double, Double)]] = _
+  @transient var expected_stats_continuous: Array[Map[Double, (Double, Double)]] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ // scalastyle:off
+    data_binary = Seq(
+      Row(0.toShort, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)),
+      Row(1.toShort, 4, 5.0, 1.0, 2.0/3, 1.0, 1.0/3, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)),
+      Row(2.toShort, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)),
+      Row(0.toShort, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)),
+      Row(1.toShort, 3, 6.0, 0.0, 2.0/3, 0.0, 2.0/3, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)),
+      Row(2.toShort, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)),
+      Row(0.toShort, 3, 7.0, 0.0, 1.0/3, 0.0, 0.0, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (1-1.0/2)*(4.0/9)),
+      Row(1.toShort, 4, 8.0, 1.0, 2.0/3, 1.0, 1.0, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), (4.0/5)*1+(1-4.0/5)*(4.0/9), (1.0/2)+(1-1.0/2)*(4.0/9)),
+      Row(2.toShort, 3, 9.0, 0.0, 1.0/3, 0.0, 0.0, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (1-1.0/2)*(4.0/9)))
+
+ data_continuous = Seq(
+ Row(0.toShort, 3, 5.0, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5),
+ Row(1.toShort, 4, 5.0, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5),
+ Row(2.toShort, 3, 5.0, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5),
+ Row(0.toShort, 4, 6.0, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0),
+ Row(1.toShort, 3, 6.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0),
+ Row(2.toShort, 4, 6.0, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0),
+ Row(0.toShort, 3, 7.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0),
+ Row(1.toShort, 4, 8.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0),
+ Row(2.toShort, 3, 9.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0))
+
+ schema = StructType(Array(
+ StructField("input1", ShortType, nullable = true),
+ StructField("input2", IntegerType, nullable = true),
+ StructField("input3", DoubleType, nullable = true),
+ StructField("label", DoubleType),
+ StructField("expected1", DoubleType),
+ StructField("expected2", DoubleType),
+ StructField("expected3", DoubleType),
+ StructField("smoothing1", DoubleType),
+ StructField("smoothing2", DoubleType),
+ StructField("smoothing3", DoubleType)))
+
+    expected_stats_binary = Array(
+      Map(0.0 -> (3.0, 1.0), 1.0 -> (3.0, 2.0), 2.0 -> (3.0, 1.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 4.0)),
+      Map(3.0 -> (5.0, 0.0), 4.0 -> (4.0, 4.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 4.0)),
+      HashMap(5.0 -> (3.0, 1.0), 6.0 -> (3.0, 2.0), 7.0 -> (1.0, 0.0), 8.0 -> (1.0, 1.0), 9.0 -> (1.0, 0.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 4.0)))
+
+    expected_stats_continuous = Array(
+      Map(0.0 -> (3.0, 40.0), 1.0 -> (3.0, 50.0), 2.0 -> (3.0, 60.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 50.0)),
+      Map(3.0 -> (5.0, 50.0), 4.0 -> (4.0, 50.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 50.0)),
+      HashMap(5.0 -> (3.0, 20.0), 6.0 -> (3.0, 50.0), 7.0 -> (1.0, 70.0), 8.0 -> (1.0, 80.0), 9.0 -> (1.0, 90.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 50.0)))
+ // scalastyle:on
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new TargetEncoder)
+ }
+
+ test("TargetEncoder - binary target") {
+
+ val df = spark.createDataFrame(sc.parallelize(data_binary), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_BINARY)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val model = encoder.fit(df)
+
+ model.stats.zip(expected_stats_binary).foreach{
+ case (actual, expected) => assert(actual.equals(expected))
+ }
+
+ testTransformer[(Double, Double, Double, Double, Double, Double)](
+ df.select("input1", "input2", "input3",
+ "expected1", "expected2", "expected3"),
+ model,
+ "output1", "expected1",
+ "output2", "expected2",
+ "output3", "expected3") {
+ case Row(output1: Double, expected1: Double,
+ output2: Double, expected2: Double,
+ output3: Double, expected3: Double) =>
+ assert(output1 === expected1)
+ assert(output2 === expected2)
+ assert(output3 === expected3)
+ }
+
+ val model_smooth = model.setSmoothing(1.0)
+
+ testTransformer[(Double, Double, Double, Double, Double, Double)](
+ df.select("input1", "input2", "input3",
+ "smoothing1", "smoothing2", "smoothing3"),
+ model_smooth,
+ "output1", "smoothing1",
+ "output2", "smoothing2",
+ "output3", "smoothing3") {
+ case Row(output1: Double, expected1: Double,
+ output2: Double, expected2: Double,
+ output3: Double, expected3: Double) =>
+ assert(output1 === expected1)
+ assert(output2 === expected2)
+ assert(output3 === expected3)
+ }
+
+
+ }
+
+ test("TargetEncoder - continuous target") {
+
+ val df = spark
+ .createDataFrame(sc.parallelize(data_continuous), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val model = encoder.fit(df)
+
+ model.stats.zip(expected_stats_continuous).foreach{
+ case (actual, expected) => assert(actual.equals(expected))
+ }
+
+ testTransformer[(Double, Double, Double, Double, Double, Double)](
+ df.select("input1", "input2", "input3",
+ "expected1", "expected2", "expected3"),
+ model,
+ "output1", "expected1",
+ "output2", "expected2",
+ "output3", "expected3") {
+ case Row(output1: Double, expected1: Double,
+ output2: Double, expected2: Double,
+ output3: Double, expected3: Double) =>
+ assert(output1 === expected1)
+ assert(output2 === expected2)
+ assert(output3 === expected3)
+ }
+
+ val model_smooth = model.setSmoothing(1.0)
+
+ testTransformer[(Double, Double, Double, Double, Double, Double)](
+ df.select("input1", "input2", "input3",
+ "smoothing1", "smoothing2", "smoothing3"),
+ model_smooth,
+ "output1", "smoothing1",
+ "output2", "smoothing2",
+ "output3", "smoothing3") {
+ case Row(output1: Double, expected1: Double,
+ output2: Double, expected2: Double,
+ output3: Double, expected3: Double) =>
+ assert(output1 === expected1)
+ assert(output2 === expected2)
+ assert(output3 === expected3)
+ }
+
+ }
+
+ test("TargetEncoder - unseen value - keep") {
+
+ val df = spark
+ .createDataFrame(sc.parallelize(data_continuous), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+ .setHandleInvalid(TargetEncoder.KEEP_INVALID)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val model = encoder.fit(df)
+
+    val data_unseen = Row(0.toShort, 3, 10.0, 0.0, 40.0, 50.0, 50.0, 0.0, 0.0, 0.0)
+
+ val df_unseen = spark
+ .createDataFrame(sc.parallelize(data_continuous :+ data_unseen), schema)
+
+ testTransformer[(Double, Double, Double, Double, Double, Double)](
+ df_unseen.select("input1", "input2", "input3",
+ "expected1", "expected2", "expected3"),
+ model,
+ "output1", "expected1",
+ "output2", "expected2",
+ "output3", "expected3") {
+ case Row(output1: Double, expected1: Double,
+ output2: Double, expected2: Double,
+ output3: Double, expected3: Double) =>
+ assert(output1 === expected1)
+ assert(output2 === expected2)
+ assert(output3 === expected3)
+ }
+ }
+
+ test("TargetEncoder - unseen value - error") {
+
+ val df = spark
+ .createDataFrame(sc.parallelize(data_continuous), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+ .setHandleInvalid(TargetEncoder.ERROR_INVALID)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val model = encoder.fit(df)
+
+    val data_unseen = Row(0.toShort, 3, 10.0, 0.0, 4.0/9, 4.0/9, 4.0/9, 0.0, 0.0, 0.0)
+
+ val df_unseen = spark
+ .createDataFrame(sc.parallelize(data_continuous :+ data_unseen), schema)
+
+ val ex = intercept[SparkRuntimeException] {
+ val output = model.transform(df_unseen)
+ output.show()
+ }
+
+ assert(ex.isInstanceOf[SparkRuntimeException])
+ assert(ex.getMessage.contains("Unseen value 10.0 in feature input3"))
+
+ }
+
+ test("TargetEncoder - missing feature") {
+
+ val df = spark
+ .createDataFrame(sc.parallelize(data_binary), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setTargetType(TargetEncoder.TARGET_BINARY)
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val ex = intercept[SparkException] {
+ val model = encoder.fit(df.drop("input3"))
+ print(model.stats)
+ }
+
+ assert(ex.isInstanceOf[SparkException])
+ assert(ex.getMessage.contains("No column named input3 found on dataset"))
+ }
+
+ test("TargetEncoder - wrong data type") {
+
+ val wrong_schema = new StructType(
+ schema.map{
+ field: StructField => if (field.name != "input3") field
+          else StructField(field.name, StringType, field.nullable, field.metadata)
+ }.toArray)
+
+ val df = spark
+ .createDataFrame(sc.parallelize(data_binary), wrong_schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setTargetType(TargetEncoder.TARGET_BINARY)
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val ex = intercept[SparkException] {
+ val model = encoder.fit(df)
+ print(model.stats)
+ }
+
+ assert(ex.isInstanceOf[SparkException])
+ assert(ex.getMessage.contains("Data type for column input3 is StringType"))
+ }
+
+ test("TargetEncoder - seen null category") {
+
+    val data_null = Row(2.toShort, 3, null, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)
+
+ val df_null = spark
+      .createDataFrame(sc.parallelize(data_continuous.dropRight(1) :+ data_null), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val model = encoder.fit(df_null)
+
+ val expected_stats = Array(
+ expected_stats_continuous(0),
+ expected_stats_continuous(1),
+      expected_stats_continuous(2) + (TargetEncoder.NULL_CATEGORY -> (1.0, 90.0)) - 9.0)
+
+ model.stats.zip(expected_stats).foreach{
+ case (actual, expected) => assert(actual.equals(expected))
+ }
+
+ val output = model.transform(df_null)
+
+    output.select(assert_true(
+      output("output1") === output("expected1") &&
+        output("output2") === output("expected2") &&
+        output("output3") === output("expected3"))).collect()
+
+ }
+
+ test("TargetEncoder - unseen null category") {
+
+ val df = spark
+ .createDataFrame(sc.parallelize(data_continuous), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+ .setHandleInvalid(TargetEncoder.KEEP_INVALID)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+    val data_null = Row(null, null, null, 90.0, 50.0, 50.0, 50.0, 57.5, 50.0, 70.0)
+
+ val df_null = spark
+ .createDataFrame(sc.parallelize(data_continuous :+ data_null), schema)
+
+ val model = encoder.fit(df)
+
+ val output = model.transform(df_null)
+
+    output.select(assert_true(
+      output("output1") === output("expected1") &&
+        output("output2") === output("expected2") &&
+        output("output3") === output("expected3"))).collect()
+
+ }
+
+ test("TargetEncoder - non-indexed categories") {
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_BINARY)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+    val data_noindex = Row(0.toShort, 3, 5.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
+
+ val df_noindex = spark
+ .createDataFrame(sc.parallelize(data_binary :+ data_noindex), schema)
+
+ val ex = intercept[SparkException] {
+ val model = encoder.fit(df_noindex)
+ print(model.stats)
+ }
+
+ assert(ex.isInstanceOf[SparkException])
+ assert(ex.getMessage.contains(
+ "Values from column input3 must be indices, but got 5.1"))
+
+ }
+
+ test("TargetEncoder - invalid label") {
+
+    val data_null = Row(2.toShort, 3, 5.0, null, 160.0, 150.0, 190.0, 57.5, 50.0, 70.0)
+    val data_nan = Row(1.toShort, 2, 6.0, Double.NaN, 160.0, 150.0, 190.0, 57.5, 50.0, 70.0)
+
+ val df_nolabel = spark
+ .createDataFrame(sc.parallelize(
+ data_continuous :+ data_null :+ data_nan), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val model = encoder.fit(df_nolabel)
+
+ model.stats.zip(expected_stats_continuous).foreach{
+ case (actual, expected) => assert(actual.equals(expected))
+ }
+
+ }
+
+ test("TargetEncoder - non-binary labels") {
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_BINARY)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+    val data_non_binary = Row(0.toShort, 3, 5.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
+
+ val df_non_binary = spark
+ .createDataFrame(sc.parallelize(data_binary :+ data_non_binary), schema)
+
+ val ex = intercept[SparkException] {
+ val model = encoder.fit(df_non_binary)
+ print(model.stats)
+ }
+
+ assert(ex.isInstanceOf[SparkException])
+ assert(ex.getMessage.contains(
+ "Values from column label must be binary (0,1) but got 2.0"))
+
+ }
+
+ test("TargetEncoder - features renamed") {
+
+ val df = spark
+ .createDataFrame(sc.parallelize(data_continuous), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val model = encoder.fit(df)
+      .setInputCols(Array("renamed_input1", "renamed_input2", "renamed_input3"))
+      .setOutputCols(Array("renamed_output1", "renamed_output2", "renamed_output3"))
+
+ val df_renamed = df
+ .withColumnsRenamed((1 to 3).map{
+ f => s"input${f}" -> s"renamed_input${f}"}.toMap)
+
+ testTransformer[(Double, Double, Double, Double, Double, Double)](
+ df_renamed
+ .select("renamed_input1", "renamed_input2", "renamed_input3",
+ "expected1", "expected2", "expected3"),
+ model,
+ "renamed_output1", "expected1",
+ "renamed_output2", "expected2",
+ "renamed_output3", "expected3") {
+ case Row(output1: Double, expected1: Double,
+ output2: Double, expected2: Double,
+ output3: Double, expected3: Double) =>
+ assert(output1 === expected1)
+ assert(output2 === expected2)
+ assert(output3 === expected3)
+ }
+
+ }
+
+ test("TargetEncoder - wrong number of features") {
+
+ val df = spark
+ .createDataFrame(sc.parallelize(data_binary), schema)
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("label")
+ .setTargetType(TargetEncoder.TARGET_BINARY)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+
+ val model = encoder.fit(df)
+
+ val ex = intercept[SparkException] {
+ val output = model
+ .setInputCols(Array("input1", "input2"))
+ .setOutputCols(Array("output1", "output2"))
+ .transform(df)
+ output.show()
+ }
+
+ assert(ex.isInstanceOf[SparkException])
+ assert(ex.getMessage.contains(
+      "does not match the number of encodings in the model (3). Found 2 features"))
+
+ }
+
+ test("TargetEncoder - R/W single-column") {
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("continuousLabel")
+ .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+ .setInputCol("input1")
+ .setOutputCol("output1")
+ .setHandleInvalid(TargetEncoder.ERROR_INVALID)
+ .setSmoothing(2)
+
+ testDefaultReadWrite(encoder)
+
+ }
+
+ test("TargetEncoder - R/W multi-column") {
+
+ val encoder = new TargetEncoder()
+ .setLabelCol("binaryLabel")
+ .setTargetType(TargetEncoder.TARGET_BINARY)
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("output1", "output2", "output3"))
+ .setHandleInvalid(TargetEncoder.KEEP_INVALID)
+ .setSmoothing(1)
+
+ testDefaultReadWrite(encoder)
+
+ }
+
+}
\ No newline at end of file
diff --git a/python/docs/source/reference/pyspark.ml.rst b/python/docs/source/reference/pyspark.ml.rst
index 965cbe7eb5a5..f81498d3b5ea 100644
--- a/python/docs/source/reference/pyspark.ml.rst
+++ b/python/docs/source/reference/pyspark.ml.rst
@@ -104,6 +104,8 @@ Feature
StopWordsRemover
StringIndexer
StringIndexerModel
+ TargetEncoder
+ TargetEncoderModel
Tokenizer
UnivariateFeatureSelector
UnivariateFeatureSelectorModel
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 9a392c9dd420..e053ea273140 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -104,6 +104,8 @@ __all__ = [
"StopWordsRemover",
"StringIndexer",
"StringIndexerModel",
+ "TargetEncoder",
+ "TargetEncoderModel",
"Tokenizer",
"UnivariateFeatureSelector",
"UnivariateFeatureSelectorModel",
@@ -5200,6 +5202,305 @@ class StopWordsRemover(
return list(stopWordsObj.loadDefaultStopWords(language))
+class _TargetEncoderParams(
+    HasLabelCol, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasHandleInvalid
+):
+ """
+ Params for :py:class:`TargetEncoder` and :py:class:`TargetEncoderModel`.
+
+ .. versionadded:: 4.0.0
+ """
+
+ handleInvalid: Param[str] = Param(
+ Params._dummy(),
+ "handleInvalid",
+ "How to handle invalid data during transform(). "
+ + "Options are 'keep' (invalid data presented as an extra "
+        + "categorical feature) or 'error' (throw an error).",
+ typeConverter=TypeConverters.toString,
+ )
+
+ targetType: Param[str] = Param(
+ Params._dummy(),
+ "targetType",
+ "whether the label is 'binary' or 'continuous'",
+ typeConverter=TypeConverters.toString,
+ )
+
+ smoothing: Param[float] = Param(
+ Params._dummy(),
+ "smoothing",
+ "value to smooth in-category averages with overall averages.",
+ typeConverter=TypeConverters.toFloat,
+ )
+
+ def __init__(self, *args: Any):
+ super(_TargetEncoderParams, self).__init__(*args)
+        self._setDefault(handleInvalid="error", targetType="binary", smoothing=0.0)
+
+ @since("4.0.0")
+ def getTargetType(self) -> str:
+ """
+ Gets the value of targetType or its default value.
+ """
+ return self.getOrDefault(self.targetType)
+
+ @since("4.0.0")
+ def getSmoothing(self) -> float:
+ """
+ Gets the value of smoothing or its default value.
+ """
+ return self.getOrDefault(self.smoothing)
+
+
+@inherit_doc
+class TargetEncoder(
+ JavaEstimator["TargetEncoderModel"],
+ _TargetEncoderParams,
+ JavaMLReadable["TargetEncoder"],
+ JavaMLWritable,
+):
+ """
+ Target Encoding maps a column of categorical indices into a numerical
feature derived
+ from the target.
+
+ When :py:attr:`handleInvalid` is configured to 'keep', previously unseen
values of
+ a feature are mapped to the dataset overall statistics.
+
+ When :py:attr:'targetType' is configured to 'binary', categories are
encoded as the
+ conditional probability of the target given that category (bin counting).
+ When :py:attr:'targetType' is configured to 'continuous', categories are
encoded as
+ the average of the target given that category (mean encoding)
+
+ Parameter :py:attr:`smoothing` controls how in-category stats and overall
+ stats are weighted to build the encodings.
+
+ .. note:: When encoding multiple columns via the `inputCols` and
+ `outputCols` params, input/output cols come in pairs, specified by the
+ order in the arrays, and each pair is treated independently.
+
+ .. versionadded:: 4.0.0
+ """
+
+ _input_kwargs: Dict[str, Any]
+
+ @overload
+ def __init__(
+ self,
+ *,
+ inputCols: Optional[List[str]] = ...,
+ outputCols: Optional[List[str]] = ...,
+ labelCol: str = ...,
+ handleInvalid: str = ...,
+ targetType: str = ...,
+ smoothing: float = ...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ *,
+ labelCol: str = ...,
+ handleInvalid: str = ...,
+ targetType: str = ...,
+ smoothing: float = ...,
+ inputCol: Optional[str] = ...,
+ outputCol: Optional[str] = ...,
+ ):
+ ...
+
+ @keyword_only
+ def __init__(
+ self,
+ *,
+ inputCols: Optional[List[str]] = None,
+ outputCols: Optional[List[str]] = None,
+ labelCol: str = "label",
+ handleInvalid: str = "error",
+ targetType: str = "binary",
+ smoothing: float = 0.0,
+ inputCol: Optional[str] = None,
+ outputCol: Optional[str] = None,
+ ):
+ """
+ __init__(self, \\*, inputCols=None, outputCols=None, labelCol="label", \
+          handleInvalid="error", targetType="binary", smoothing=0.0, \
+          inputCol=None, outputCol=None)
+ """
+ super(TargetEncoder, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.TargetEncoder", self.uid)
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @overload
+ def setParams(
+ self,
+ *,
+ inputCols: Optional[List[str]] = ...,
+ outputCols: Optional[List[str]] = ...,
+ labelCol: str = ...,
+ handleInvalid: str = ...,
+ targetType: str = ...,
+ smoothing: float = ...,
+ ) -> "TargetEncoder":
+ ...
+
+ @overload
+ def setParams(
+ self,
+ *,
+ labelCol: str = ...,
+ handleInvalid: str = ...,
+ targetType: str = ...,
+ smoothing: float = ...,
+ inputCol: Optional[str] = ...,
+ outputCol: Optional[str] = ...,
+ ) -> "TargetEncoder":
+ ...
+
+ @keyword_only
+ @since("4.0.0")
+ def setParams(
+ self,
+ *,
+ inputCols: Optional[List[str]] = None,
+ outputCols: Optional[List[str]] = None,
+ labelCol: str = "label",
+ handleInvalid: str = "error",
+ targetType: str = "binary",
+ smoothing: float = 0.0,
+ inputCol: Optional[str] = None,
+ outputCol: Optional[str] = None,
+ ) -> "TargetEncoder":
+ """
+ setParams(self, \\*, inputCols=None, outputCols=None, labelCol="label", \
+           handleInvalid="error", targetType="binary", smoothing=0.0, \
+           inputCol=None, outputCol=None)
+ Sets params for this TargetEncoder.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ @since("4.0.0")
+ def setLabelCol(self, value: str) -> "TargetEncoder":
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ @since("4.0.0")
+ def setInputCols(self, value: List[str]) -> "TargetEncoder":
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ @since("4.0.0")
+ def setOutputCols(self, value: List[str]) -> "TargetEncoder":
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
+ @since("4.0.0")
+ def setInputCol(self, value: str) -> "TargetEncoder":
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("4.0.0")
+ def setOutputCol(self, value: str) -> "TargetEncoder":
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ @since("4.0.0")
+ def setHandleInvalid(self, value: str) -> "TargetEncoder":
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
+ @since("4.0.0")
+ def setTargetType(self, value: str) -> "TargetEncoder":
+ """
+ Sets the value of :py:attr:`targetType`.
+ """
+ return self._set(targetType=value)
+
+ @since("4.0.0")
+ def setSmoothing(self, value: float) -> "TargetEncoder":
+ """
+ Sets the value of :py:attr:`smoothing`.
+ """
+ return self._set(smoothing=value)
+
+ def _create_model(self, java_model: "JavaObject") -> "TargetEncoderModel":
+ return TargetEncoderModel(java_model)
+
+
+class TargetEncoderModel(
+ JavaModel, _TargetEncoderParams, JavaMLReadable["TargetEncoderModel"], JavaMLWritable
+):
+ """
+ Model fitted by :py:class:`TargetEncoder`.
+
+ .. versionadded:: 4.0.0
+ """
+
+ @since("4.0.0")
+ def setInputCols(self, value: List[str]) -> "TargetEncoderModel":
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ @since("4.0.0")
+ def setOutputCols(self, value: List[str]) -> "TargetEncoderModel":
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
+ @since("4.0.0")
+ def setInputCol(self, value: str) -> "TargetEncoderModel":
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("4.0.0")
+ def setOutputCol(self, value: str) -> "TargetEncoderModel":
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ @since("4.0.0")
+ def setHandleInvalid(self, value: str) -> "TargetEncoderModel":
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
+ @since("4.0.0")
+ def setSmoothing(self, value: float) -> "TargetEncoderModel":
+ """
+ Sets the value of :py:attr:`smoothing`.
+ """
+ return self._set(smoothing=value)
+
+ @property
+ @since("4.0.0")
+ def stats(self) -> List[Dict[float, Tuple[float, float]]]:
+ """
+ Fitted statistics for each feature being encoded.
+ The list contains a dictionary for each input column.
+ """
+ return self._call_java("stats")
+
+
@inherit_doc
class Tokenizer(
JavaTransformer,
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index 4bf6641723da..92919adecd06 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -29,6 +29,7 @@ from pyspark.ml.feature import (
StopWordsRemover,
StringIndexer,
StringIndexerModel,
+ TargetEncoder,
VectorSizeHint,
)
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
@@ -346,6 +347,171 @@ class FeatureTests(SparkSessionTestCase):
)
self.assertEqual(len(transformed_list), 5)
+ def test_target_encoder_binary(self):
+ df = self.spark.createDataFrame(
+ [
+ (0, 3, 5.0, 0.0),
+ (1, 4, 5.0, 1.0),
+ (2, 3, 5.0, 0.0),
+ (0, 4, 6.0, 1.0),
+ (1, 3, 6.0, 0.0),
+ (2, 4, 6.0, 1.0),
+ (0, 3, 7.0, 0.0),
+ (1, 4, 8.0, 1.0),
+ (2, 3, 9.0, 0.0),
+ ],
+ schema="input1 short, input2 int, input3 double, label double",
+ )
+ encoder = TargetEncoder(
+ inputCols=["input1", "input2", "input3"],
+ outputCols=["output1", "output2", "output3"],
+ labelCol="label",
+ targetType="binary",
+ )
+ model = encoder.fit(df)
+ te = model.transform(df)
+ actual = te.drop("label").collect()
+ expected = [
+ Row(input1=0, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
+ Row(input1=1, input2=4, input3=5.0, output1=2.0 / 3, output2=1.0, output3=1.0 / 3),
+ Row(input1=2, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
+ Row(input1=0, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
+ Row(input1=1, input2=3, input3=6.0, output1=2.0 / 3, output2=0.0, output3=2.0 / 3),
+ Row(input1=2, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
+ Row(input1=0, input2=3, input3=7.0, output1=1.0 / 3, output2=0.0, output3=0.0),
+ Row(input1=1, input2=4, input3=8.0, output1=2.0 / 3, output2=1.0, output3=1.0),
+ Row(input1=2, input2=3, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0),
+ ]
+ self.assertEqual(actual, expected)
+ te = model.setSmoothing(1.0).transform(df)
+ actual = te.drop("label").collect()
+ expected = [
+ Row(
+ input1=0,
+ input2=3,
+ input3=5.0,
+ output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
+ output2=(1 - 5 / 6) * (4 / 9),
+ output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
+ ),
+ Row(
+ input1=1,
+ input2=4,
+ input3=5.0,
+ output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
+ output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
+ output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
+ ),
+ Row(
+ input1=2,
+ input2=3,
+ input3=5.0,
+ output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
+ output2=(1 - 5 / 6) * (4 / 9),
+ output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
+ ),
+ Row(
+ input1=0,
+ input2=4,
+ input3=6.0,
+ output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
+ output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
+ output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
+ ),
+ Row(
+ input1=1,
+ input2=3,
+ input3=6.0,
+ output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
+ output2=(1 - 5 / 6) * (4 / 9),
+ output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
+ ),
+ Row(
+ input1=2,
+ input2=4,
+ input3=6.0,
+ output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
+ output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
+ output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
+ ),
+ Row(
+ input1=0,
+ input2=3,
+ input3=7.0,
+ output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
+ output2=(1 - 5 / 6) * (4 / 9),
+ output3=(1 - 1 / 2) * (4 / 9),
+ ),
+ Row(
+ input1=1,
+ input2=4,
+ input3=8.0,
+ output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9),
+ output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9),
+ output3=(1 / 2) + (1 - 1 / 2) * (4 / 9),
+ ),
+ Row(
+ input1=2,
+ input2=3,
+ input3=9.0,
+ output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9),
+ output2=(1 - 5 / 6) * (4 / 9),
+ output3=(1 - 1 / 2) * (4 / 9),
+ ),
+ ]
+ self.assertEqual(actual, expected)
+
+ def test_target_encoder_continuous(self):
+ df = self.spark.createDataFrame(
+ [
+ (0, 3, 5.0, 10.0),
+ (1, 4, 5.0, 20.0),
+ (2, 3, 5.0, 30.0),
+ (0, 4, 6.0, 40.0),
+ (1, 3, 6.0, 50.0),
+ (2, 4, 6.0, 60.0),
+ (0, 3, 7.0, 70.0),
+ (1, 4, 8.0, 80.0),
+ (2, 3, 9.0, 90.0),
+ ],
+ schema="input1 short, input2 int, input3 double, label double",
+ )
+ encoder = TargetEncoder(
+ inputCols=["input1", "input2", "input3"],
+ outputCols=["output1", "output2", "output3"],
+ labelCol="label",
+ targetType="continuous",
+ )
+ model = encoder.fit(df)
+ te = model.transform(df)
+ actual = te.drop("label").collect()
+ expected = [
+ Row(input1=0, input2=3, input3=5.0, output1=40.0, output2=50.0, output3=20.0),
+ Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=20.0),
+ Row(input1=2, input2=3, input3=5.0, output1=60.0, output2=50.0, output3=20.0),
+ Row(input1=0, input2=4, input3=6.0, output1=40.0, output2=50.0, output3=50.0),
+ Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0),
+ Row(input1=2, input2=4, input3=6.0, output1=60.0, output2=50.0, output3=50.0),
+ Row(input1=0, input2=3, input3=7.0, output1=40.0, output2=50.0, output3=70.0),
+ Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=80.0),
+ Row(input1=2, input2=3, input3=9.0, output1=60.0, output2=50.0, output3=90.0),
+ ]
+ self.assertEqual(actual, expected)
+ te = model.setSmoothing(1.0).transform(df)
+ actual = te.drop("label").collect()
+ expected = [
+ Row(input1=0, input2=3, input3=5.0, output1=42.5, output2=50.0, output3=27.5),
+ Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=27.5),
+ Row(input1=2, input2=3, input3=5.0, output1=57.5, output2=50.0, output3=27.5),
+ Row(input1=0, input2=4, input3=6.0, output1=42.5, output2=50.0, output3=50.0),
+ Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0),
+ Row(input1=2, input2=4, input3=6.0, output1=57.5, output2=50.0, output3=50.0),
+ Row(input1=0, input2=3, input3=7.0, output1=42.5, output2=50.0, output3=60.0),
+ Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=65.0),
+ Row(input1=2, input2=3, input3=9.0, output1=57.5, output2=50.0, output3=70.0),
+ ]
+ self.assertEqual(actual, expected)
+
def test_vector_size_hint(self):
df = self.spark.createDataFrame(
[
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
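
P.S. For readers working through the expected values in the tests above: the suite exercises the blended statistic described in the docstring, where each category's encoding is a count-weighted mix of its in-category mean and the global mean, with weight n / (n + smoothing). The following is a minimal pure-Python sketch of that arithmetic only (no Spark required); `target_encode` is a hypothetical helper written for illustration, not the PR's implementation, which fits via treeAggregate on the RDD API and transforms via DataFrame expressions.

```python
def target_encode(categories, targets, smoothing=0.0):
    """Encode each category as a blend of its in-category target mean and
    the global target mean, weighted by category count:
        w = n / (n + smoothing);  encoding = w * cat_mean + (1 - w) * global_mean
    With a binary 0/1 target the in-category mean is the conditional
    probability of the target (bin counting); with a continuous target it
    is plain mean encoding."""
    global_mean = sum(targets) / len(targets)
    sums, counts = {}, {}
    for cat, t in zip(categories, targets):
        sums[cat] = sums.get(cat, 0.0) + t
        counts[cat] = counts.get(cat, 0) + 1
    encodings = {}
    for cat, n in counts.items():
        w = n / (n + smoothing)
        encodings[cat] = w * (sums[cat] / n) + (1.0 - w) * global_mean
    return [encodings[cat] for cat in categories]
```

Feeding in the `input1`/`label` columns from `test_target_encoder_binary` (global mean 4/9, category 0 mean 1/3 over 3 rows) reproduces the unsmoothed encoding 1/3 and, with smoothing=1.0, the blended (3/4)(1/3) + (1/4)(4/9) that the test asserts.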